Graph classification

Some datasets come with intrinsic values that can be used as filtrations, e.g., graphs, medical images, or molecules. Multiparameter persistence is well suited to handling all of this information at once. To illustrate this, we consider the BZR graph dataset. It can be found here.

import multipers.data.graphs as mdg
import multipers.ml.signed_measures as mms
import networkx as nx
from random import choice
dataset = "graphs/BZR"
path = mdg.DATASET_PATH+dataset
!ls $path ## We assume that the dataset is in this folder. You can modify the variable `mdg.DATASET_PATH` if necessary
BZR.edges      BZR.graph_labels  BZR.node_labels  graphs.pkl  readme.html
BZR.graph_idx  BZR.node_attrs	 BZR.readme	  labels.pkl
graphs, labels = mdg.get_graphs(dataset)
nx.draw(choice(graphs))
[Figure: drawing of a randomly chosen graph from the BZR dataset]

Graphs can be filtered by several node functions: node degree, intrinsic node attributes, Ricci curvature, closeness centrality, heat kernel signature, etc.
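To make two of these node filtrations concrete, here is a minimal pure-Python sketch computing node degree and closeness centrality on a toy path graph. The helper names are hypothetical, and the normalizations (degree divided by the maximum degree, closeness as (n - 1) divided by the sum of hop distances) are assumptions. The `degree` and `cc` values that multipers produces may be normalized differently.

```python
from collections import deque


def bfs_distances(adj, source):
    """Hop distances from `source` to every reachable node (breadth-first search)."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    return dist


def degree_filtration(adj):
    """Node degree rescaled by the maximum degree (this normalization is an assumption)."""
    max_deg = max(len(nbrs) for nbrs in adj.values())
    return {v: len(nbrs) / max_deg for v, nbrs in adj.items()}


def closeness_filtration(adj):
    """Closeness centrality (n - 1) / (sum of hop distances), for a connected graph with n >= 2."""
    n = len(adj)
    return {v: (n - 1) / sum(bfs_distances(adj, v).values()) for v in adj}


# Toy path graph 0 - 1 - 2 - 3, stored as an adjacency dict.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
print(degree_filtration(adj))     # {0: 0.5, 1: 1.0, 2: 1.0, 3: 0.5}
print(closeness_filtration(adj))  # {0: 0.5, 1: 0.75, 2: 0.75, 3: 0.5}
```

Each function maps every node to a single real number, so stacking several of them gives each node a vector of filtration values, as in the dictionary printed below.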

## uncomment this line to compute filtrations on the graphs
# mdg.compute_filtration(dataset, filtration="ALL")
graphs, labels = mdg.get_graphs(dataset) # Retrieves these filtrations
g = graphs[0] # First graph of the dataset
g.nodes[0] # First node of this graph, which holds several filtration values
{'intrinsic': array([-2.626347,  2.492403,  0.061623]),
 'geodesic': 0,
 'cc': 0.2116788321167883,
 'degree': 0.3333333333333333,
 'ricciCurvature': -0.33333333333333326,
 'fiedler': 0.03371257377270488,
 'hks_10': 0.13721841042259394,
 'hks_1': 0.47648745932718833}
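The `hks_1` and `hks_10` entries are heat kernel signatures at times t = 1 and t = 10. Below is a minimal NumPy sketch, assuming the combinatorial Laplacian L = D - A (multipers may use a normalized Laplacian instead). With eigenpairs (λ_i, φ_i) of L, hks_t(v) = Σ_i exp(-t λ_i) φ_i(v)^2, which is the diagonal of the heat kernel exp(-tL). The function name is hypothetical.

```python
import numpy as np


def heat_kernel_signature(adjacency, t):
    """hks_t(v) = sum_i exp(-t * lam_i) * phi_i(v)**2, assuming the Laplacian L = D - A."""
    A = np.asarray(adjacency, dtype=float)
    L = np.diag(A.sum(axis=1)) - A          # combinatorial graph Laplacian
    eigvals, eigvecs = np.linalg.eigh(L)    # columns of eigvecs are orthonormal eigenvectors
    # Row v of eigvecs**2 holds phi_i(v)**2; weight each eigenpair by exp(-t * lam_i) and sum.
    return (eigvecs ** 2) @ np.exp(-t * eigvals)


# Path graph 0 - 1 - 2 - 3 as an adjacency matrix.
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]])
print(heat_kernel_signature(A, 1.0))
print(heat_kernel_signature(A, 10.0))
```

At t = 0 every value equals 1. On a connected graph, the values tend to 1/n as t grows, so small times reflect local structure around a node and large times reflect more global structure.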

As with point clouds, we can turn the filtered graphs into simplextrees, and then turn these into signed measures.
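One natural way to turn node values into a multifiltered graph is the lower-star rule. Each vertex is born at its vector of filtration values, and each edge is born at the coordinatewise maximum of its endpoints' values, so an edge never appears before its vertices. The sketch below assumes this is roughly what `Graph2SimplexTrees` does for node-valued filtrations, which we have not checked. It returns plain (simplex, grade) pairs rather than a multipers simplextree, and the function name is hypothetical.

```python
def lower_star_bifiltration(node_values, edges):
    """Grade of each simplex: f(v) for a vertex, coordinatewise max of its endpoints for an edge."""
    simplices = [((v,), tuple(vals)) for v, vals in node_values.items()]
    for u, v in edges:
        fu, fv = node_values[u], node_values[v]
        simplices.append(((u, v), tuple(max(a, b) for a, b in zip(fu, fv))))
    return simplices


# Path graph 0 - 1 - 2 with two (made-up) filtration values per node.
node_values = {0: (0.5, 0.2), 1: (1.0, 0.1), 2: (0.5, 0.3)}
edges = [(0, 1), (1, 2)]
for simplex, grade in lower_star_bifiltration(node_values, edges):
    print(simplex, grade)
```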

simplextrees = mdg.Graph2SimplexTrees(filtrations=["hks_10", "degree", "geodesic", "cc"]).fit_transform(graphs)
signed_measures = mms.SimplexTrees2SignedMeasures(
    degrees=[None], n_jobs=1, grid_strategy='exact', enforce_null_mass=True
).fit_transform(simplextrees) # degrees=[None] corresponds to the Euler characteristic, which is significantly faster to compute on graphs.
# One may want to rescale the filtrations w.r.t. each other. This can be done using the SignedMeasureFormatter class
signed_measures = mms.SignedMeasureFormatter(normalize=True, axis=0).fit_transform(signed_measures)
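With `degrees=[None]`, the invariant is the Euler characteristic. For a grade x, χ(x) = Σ (-1)^{dim σ} over the simplices σ with f(σ) ≤ x coordinatewise. The signed measure whose cumulative sum recovers χ is therefore Σ_σ (-1)^{dim σ} δ_{f(σ)}, which places +1 at each vertex grade and -1 at each edge grade. This only requires counting simplices, with no homology computation, which helps explain why it is cheap on graphs. The plain-Python sketch below uses hypothetical function names and does not reproduce how multipers handles `grid_strategy` or `enforce_null_mass`.

```python
from collections import defaultdict


def euler_signed_measure(simplices):
    """Signed measure sum_sigma (-1)**dim(sigma) * delta_{f(sigma)}, merging equal points."""
    weights = defaultdict(int)
    for vertices, grade in simplices:
        weights[grade] += (-1) ** (len(vertices) - 1)   # +1 for a vertex, -1 for an edge
    return {point: w for point, w in weights.items() if w != 0}


def euler_characteristic(simplices, x):
    """Euler characteristic of the subcomplex of simplices with f(sigma) <= x coordinatewise."""
    return sum((-1) ** (len(s) - 1) for s, grade in simplices
               if all(g <= xi for g, xi in zip(grade, x)))


# A bifiltered path graph 0 - 1 - 2: (simplex, grade) pairs, edges at the max of their vertices.
simplices = [((0,), (0.5, 0.2)), ((1,), (1.0, 0.1)), ((2,), (0.5, 0.3)),
             ((0, 1), (1.0, 0.2)), ((1, 2), (1.0, 0.3))]
measure = euler_signed_measure(simplices)
print(measure)
print(euler_characteristic(simplices, (1.0, 0.3)))  # 1 (3 vertices - 2 edges)
```

At any grade x, summing the weights of the measure's points lying below x gives back χ(x).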

Finally, we can classify these graphs using either a sliced Wasserstein kernel or a convolution. Here we use the sliced Wasserstein kernel.
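A sliced Wasserstein distance between signed measures μ = μ+ - μ- and ν = ν+ - ν- can be defined by comparing the two positive measures μ+ + ν- and ν+ + μ-. These have equal mass whenever μ and ν do, which is presumably one purpose of `enforce_null_mass=True` above. Each measure is projected onto random directions, and in one dimension optimal transport reduces to matching sorted projections. The NumPy sketch below assumes integer weights, random Gaussian directions, and the kernel exp(-SW / (2σ^2)). `DistanceMatrix2Kernel` may use a different formula, and all function names are hypothetical.

```python
import numpy as np


def _expand(measure):
    """Split a signed measure {point: int weight} into positive and negative point lists."""
    pos, neg = [], []
    for point, w in measure.items():
        (pos if w > 0 else neg).extend([point] * abs(w))
    return pos, neg


def sliced_wasserstein(mu, nu, num_directions=50, seed=0):
    """Sliced W1 between signed measures, computed as SW(mu+ + nu-, nu+ + mu-)."""
    mu_pos, mu_neg = _expand(mu)
    nu_pos, nu_neg = _expand(nu)
    a = np.array(mu_pos + nu_neg, dtype=float)
    b = np.array(nu_pos + mu_neg, dtype=float)
    if len(a) != len(b):
        raise ValueError("signed measures must have the same total mass")
    if len(a) == 0:
        return 0.0
    rng = np.random.default_rng(seed)
    thetas = rng.normal(size=(num_directions, a.shape[1]))
    thetas /= np.linalg.norm(thetas, axis=1, keepdims=True)  # random unit directions
    # Project on every direction, sort, and match points in order (1-D optimal transport).
    proj_a = np.sort(a @ thetas.T, axis=0)
    proj_b = np.sort(b @ thetas.T, axis=0)
    return float(np.abs(proj_a - proj_b).sum(axis=0).mean())


def kernel_from_distances(D, sigma=1.0):
    """Kernel exp(-D / (2 sigma^2)); the formula used by multipers is an assumption."""
    return np.exp(-np.asarray(D) / (2 * sigma ** 2))


# Two toy signed measures with null total mass.
mu = {(0.5, 0.2): 1, (1.0, 0.2): -1}
nu = {(0.5, 0.3): 1, (1.0, 0.3): -1}
D = np.array([[sliced_wasserstein(p, q) for q in (mu, nu)] for p in (mu, nu)])
K = kernel_from_distances(D, sigma=1.0)
print(D)
print(K)
```

Computing this distance for every pair of signed measures gives a distance matrix. Turned into a kernel, it can be fed to an SVM with `kernel="precomputed"`, which is the structure of the pipeline below.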

from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from multipers.ml.kernels import DistanceMatrix2Kernel
## Split the data into train and test sets (no fixed random_state, so the split and the score below vary between runs)
xtrain,xtest,ytrain,ytest = train_test_split(signed_measures, labels)
## Classification pipeline using the sliced wasserstein kernel
classifier = Pipeline([
    ("SWD",mms.SignedMeasure2SlicedWassersteinDistance(n_jobs=-1, num_directions=50)),
    ("KERNEL", DistanceMatrix2Kernel(sigma=1.)),
    ("SVM", SVC(kernel="precomputed")),
])
## Evaluates the classifier (accuracy on the test set)
# Note that there is no cross-validation or hyperparameter tuning here, so results can likely be improved significantly
classifier.fit(xtrain,ytrain).score(xtest,ytest)
0.9117647058823529