Graph classification
Some datasets come with intrinsic values that can be used as filtrations, e.g., graphs, medical images, or molecules. Multiparameter persistence is well suited to handling all of this information at once. To illustrate this, we consider the BZR graph dataset. It can be found here.
import multipers.data.graphs as mdg
import multipers.ml.signed_measures as mms
import networkx as nx
from random import choice
from os.path import expanduser
dataset = "graphs/BZR"
path = mdg.DATASET_PATH+dataset
!ls $path ## We assume that the dataset is in this folder. You can modify the variable `mdg.DATASET_PATH` if necessary
BZR.edges BZR.graph_labels BZR.node_labels graphs.pkl readme.html
BZR.graph_idx BZR.node_attrs BZR.readme labels.pkl
graphs, labels = mdg.get_graphs(dataset)
nx.draw(choice(graphs))
![../_images/aba775975802724171127174c3d38067947ccffd093b2993ceffb75ac4dc5a94.png](../_images/aba775975802724171127174c3d38067947ccffd093b2993ceffb75ac4dc5a94.png)
Graph datasets can be filtered by several filtrations: node degrees, intrinsic values, Ricci curvature, closeness centrality, heat kernel signature, etc.
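To make two of these filtrations concrete, here is a minimal sketch in plain Python that computes node degrees and closeness centrality on a small toy graph. The graph, the function names, and the closeness convention, (n - 1) divided by the sum of shortest-path distances, are assumptions made for illustration; the values stored by `multipers` may be normalized differently.

```python
from collections import deque

# Toy connected graph as an adjacency dict (hypothetical data, not from BZR).
adjacency = {0: [1], 1: [0, 2, 3], 2: [1, 3], 3: [1, 2]}


def degree_filtration(adjacency):
    """Number of neighbors of each node."""
    return {node: len(neighbors) for node, neighbors in adjacency.items()}


def closeness_filtration(adjacency):
    """(n - 1) / (sum of shortest-path distances); assumes a connected graph with n >= 2."""
    n = len(adjacency)
    values = {}
    for source in adjacency:
        # Breadth-first search gives shortest-path distances in an unweighted graph.
        dist = {source: 0}
        queue = deque([source])
        while queue:
            u = queue.popleft()
            for v in adjacency[u]:
                if v not in dist:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        values[source] = (n - 1) / sum(dist.values())
    return values


degree = degree_filtration(adjacency)
closeness = closeness_filtration(adjacency)
for node in adjacency:
    print(node, degree[node], closeness[node])
```

Each function returns one number per node, which is the shape of data a node-based filtration needs.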
## uncomment this line to compute filtrations on the graphs
# mdg.compute_filtration(dataset, filtration="ALL")
graphs, labels = mdg.get_graphs(dataset) # Retrieves these filtrations
g = graphs[0] # First graph of the dataset
g.nodes[0] # First node of this graph, which holds several filtration values
{'intrinsic': array([-2.626347, 2.492403, 0.061623]),
'geodesic': 0,
'cc': 0.2116788321167883,
'degree': 0.3333333333333333,
'ricciCurvature': -0.33333333333333326,
'fiedler': 0.03371257377270488,
'hks_10': 0.13721841042259394,
'hks_1': 0.47648745932718833}
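Node attributes like the ones above only assign values to vertices, while a filtration of the graph also needs a value for each edge. One common convention is the lower-star extension, where an edge appears as soon as both of its endpoints are present. It is assumed here for illustration and is not confirmed as what `Graph2SimplexTrees` does. With several filtrations this becomes a coordinate-wise maximum. The toy data and the `lower_star` helper below are hypothetical.

```python
# Toy graph as an edge list (hypothetical data, not taken from BZR).
edges = [(0, 1), (1, 2), (2, 3), (1, 3)]

# Two filtration values per node; the numbers are made up for illustration.
node_values = {
    0: (1.0, 0.50),
    1: (3.0, 0.75),
    2: (2.0, 0.60),
    3: (2.0, 0.60),
}


def lower_star(edges, node_values):
    """Give each edge the coordinate-wise maximum of its endpoints' values."""
    edge_values = {}
    for u, v in edges:
        edge_values[(u, v)] = tuple(
            max(a, b) for a, b in zip(node_values[u], node_values[v])
        )
    return edge_values


edge_values = lower_star(edges, node_values)
for edge, value in edge_values.items():
    print(edge, value)
```

Choosing k node attributes in this way yields a k-parameter filtration of the graph.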
As with point clouds, we can create simplex trees and turn them into signed measures.
simplextrees = mdg.Graph2SimplexTrees(filtrations=["hks_10","degree","geodesic", "cc"]).fit_transform(graphs)
signed_measures = mms.SimplexTrees2SignedMeasures(
    degrees=[None], n_jobs=1, grid_strategy='exact', enforce_null_mass=True
).fit_transform(simplextrees) # None corresponds to the Euler characteristic, which is significantly faster to compute on graphs.
# One may want to rescale filtrations w.r.t. each other. This can be done using the SignedMeasureFormatter class
signed_measures = mms.SignedMeasureFormatter(normalize=True, axis=0).fit_transform(signed_measures)
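The comment above notes that `degrees=[None]` selects the Euler characteristic. As a rough illustration of why this is cheap on graphs, here is a minimal sketch, assuming a filtration in which every vertex and edge enters at a single point of the parameter space. Under that assumption the Euler characteristic at a parameter x is the number of vertices present minus the number of edges present, so the corresponding signed measure puts weight +1 at each vertex value and -1 at each edge value. The function names and toy data are hypothetical, and this is not the `multipers` implementation.

```python
from collections import Counter

# Toy path graph 0 - 1 - 2 with two filtration values per simplex (made-up data).
node_values = {0: (0, 0), 1: (1, 0), 2: (0, 1)}
edge_values = {(0, 1): (1, 0), (1, 2): (1, 1)}  # coordinate-wise max of endpoints


def euler_signed_measure(node_values, edge_values):
    """Weight +1 at each vertex value, -1 at each edge value; cancelled points are dropped."""
    measure = Counter()
    for value in node_values.values():
        measure[value] += 1
    for value in edge_values.values():
        measure[value] -= 1
    return {point: weight for point, weight in measure.items() if weight != 0}


def euler_characteristic(measure, x):
    """Sum the weights of all points that are coordinate-wise below x."""
    return sum(
        weight
        for point, weight in measure.items()
        if all(p <= xi for p, xi in zip(point, x))
    )


measure = euler_signed_measure(node_values, edge_values)
print(measure)
print(euler_characteristic(measure, (0, 1)))  # two isolated vertices are present
print(euler_characteristic(measure, (1, 1)))  # the whole path is present
```

No homology computation is involved: a single pass over the vertices and edges is enough.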
Finally, we classify these graphs using either a sliced Wasserstein kernel or a convolution.
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from multipers.ml.kernels import DistanceMatrix2Kernel
## Split the data into train and test sets
xtrain,xtest,ytrain,ytest = train_test_split(signed_measures, labels)
## Classification pipeline using the sliced wasserstein kernel
classifier = Pipeline([
    ("SWD", mms.SignedMeasure2SlicedWassersteinDistance(n_jobs=-1, num_directions=50)),
    ("KERNEL", DistanceMatrix2Kernel(sigma=1.)),
    ("SVM", SVC(kernel="precomputed")),
])
## Evaluates the classifier on this dataset.
# Note that there is no cross-validation here, so results can be significantly improved
classifier.fit(xtrain,ytrain).score(xtest,ytest)
0.9117647058823529
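The `SWD` step above relies on a sliced Wasserstein distance between signed measures. The following is a minimal sketch of one way such a distance can be computed, assuming 2D measures with integer weights and equal total mass. Negative masses are moved to the opposite side, both sides are projected onto a set of directions, and the 1D transport costs between sorted projections are averaged. The function names, the choice of directions, and the normalization are assumptions for illustration, not the library's code.

```python
import math


def expand(measure, sign):
    """List the points whose weight has the given sign, repeated |weight| times."""
    return [p for p, w in measure.items() if w * sign > 0 for _ in range(abs(w))]


def sliced_wasserstein(mu, nu, num_directions=50):
    """Approximate sliced 1-Wasserstein cost between two 2D signed measures."""
    # Compare (mu+ + nu-) with (nu+ + mu-), so that both sides are positive measures.
    left = expand(mu, +1) + expand(nu, -1)
    right = expand(nu, +1) + expand(mu, -1)
    if len(left) != len(right):
        raise ValueError("the two measures must have the same total mass")
    if not left:
        return 0.0
    total = 0.0
    for k in range(num_directions):
        theta = math.pi * k / num_directions  # evenly spaced directions on a half circle
        c, s = math.cos(theta), math.sin(theta)
        a = sorted(c * x + s * y for x, y in left)
        b = sorted(c * x + s * y for x, y in right)
        # In 1D, optimal transport between equal-size point sets pairs sorted points.
        total += sum(abs(p - q) for p, q in zip(a, b))
    return total / num_directions


mu = {(0, 0): 1, (1, 1): -1}
nu = {(1, 0): 1, (2, 1): -1}  # mu translated by (1, 0)
print(sliced_wasserstein(mu, mu))
print(sliced_wasserstein(mu, nu))
```

A matrix of such pairwise distances is what a distance-to-kernel step would then consume before the SVM is fitted.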