from generate_datasets import make_point_clouds
import matplotlib.pyplot as plt
n_samples_per_class = 1
# Generate the dataset; judging by the indexing used further down, the clouds
# come grouped by class, n_samples_per_class at a time.
point_clouds, labels = make_point_clouds(n_samples_per_class, 15, 0.1)
cloud_shape = point_clouds.shape
n_clouds, n_points, n_dims = cloud_shape[0], cloud_shape[1], cloud_shape[2]
print(f"There are {n_clouds} point clouds in {n_dims} dimensions, "
      f"each with {n_points} points.")
There are 3 point clouds in 3 dimensions, each with 225 points.
# Display the array shape: (number of point clouds, points per cloud, dimensions),
# the same axis order used when printing the summary above.
point_clouds.shape
(3, 225, 3)
def plot_in_axis(ax, point_cloud):
    """Scatter the points of ``point_cloud`` onto ``ax`` and return ``ax``.

    ``point_cloud`` is an iterable of (x, y, z) triples; it is transposed into
    one list per coordinate before being handed to ``ax.scatter``.
    """
    x_vals, y_vals, z_vals = (list(column) for column in zip(*point_cloud))
    ax.scatter(x_vals, y_vals, z_vals)
    return ax
def plot_3d_data(point_cloud):
    """Scatter-plot a single 3-D point cloud in a new figure and show it.

    Parameters
    ----------
    point_cloud : iterable of (x, y, z) points
        The cloud to draw.
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    # plot_in_axis needs both the target axis and the points to draw.
    plot_in_axis(ax, point_cloud)
    plt.show()
def plot_3_figures(point_cloud_1, point_cloud_2, point_cloud_3):
    """Plot three 3-D point clouds side by side in one figure row and show it.

    Parameters
    ----------
    point_cloud_1, point_cloud_2, point_cloud_3 : iterable of (x, y, z) points
        Drawn left to right, each in its own 3-D subplot.
    """
    # figaspect(0.3) gives height = 0.3 * width, i.e. a figure roughly three
    # times as wide as it is tall.
    fig = plt.figure(figsize=plt.figaspect(0.3))
    point_clouds = [point_cloud_1, point_cloud_2, point_cloud_3]
    # Subplot positions are 1-based.
    for position, point_cloud in enumerate(point_clouds, start=1):
        axis = fig.add_subplot(1, len(point_clouds), position, projection='3d')
        plot_in_axis(axis, point_cloud)
    plt.show()
Let's take one point cloud from each of the three categories (circumference, sphere and torus).
# Classes are stored one after another, n_samples_per_class clouds each:
# take the first cloud of each class.
circumference, sphere, torus = (
    point_clouds[class_idx * n_samples_per_class] for class_idx in range(3)
)
plot_3_figures(circumference, sphere, torus)
We will compute persistence diagrams for both the Vietoris-Rips and the Cech filtrations.
Vietoris-Rips filtration
from gtda.homology import VietorisRipsPersistence
# Vietoris-Rips persistence in homology dimensions 0, 1 and 2, computed for
# every point cloud in a single call.
VR = VietorisRipsPersistence(homology_dimensions=[0, 1, 2])
diagrams_VR = VR.fit_transform(point_clouds)
# One diagram per point cloud; each row is presumably a
# (birth, death, homology dimension) triple, with all diagrams padded by gtda
# to a common length (TODO confirm against the gtda docs).
diagrams_VR.shape
(3, 402, 3)
Cech filtration
import numpy as np
def filter_Cech_diagram(CH_diagram, atol=.01):
    """Drop (near-)diagonal points from a single-sample Cech persistence diagram.

    Parameters
    ----------
    CH_diagram : array-like of shape (1, n_points, >=3)
        Output of ``EuclideanCechPersistence.fit_transform`` for ONE point
        cloud. Only ``CH_diagram[0]`` is read; column 0 is treated as birth,
        column 1 as death, and columns 0-2 are kept.
    atol : float, default=0.01
        Absolute tolerance passed to ``np.isclose``; points whose birth and
        death are close within it are considered diagonal and removed.

    Returns
    -------
    numpy.ndarray of shape (1, n_kept, 3)
        The remaining points, with a leading sample axis so the result can be
        stacked with other diagrams. ``n_kept`` may be 0.
    """
    diagram = np.asarray(CH_diagram[0])
    if diagram.size == 0:
        # Keep the (1, n, 3) layout even when there is nothing to keep.
        return np.empty((1, 0, 3))
    keep = ~np.isclose(diagram[:, 0], diagram[:, 1], atol=atol)
    return diagram[keep, :3][np.newaxis]
from gtda.homology import EuclideanCechPersistence
CH = EuclideanCechPersistence(homology_dimensions=[0, 1])
# EuclideanCechPersistence currently only works with a list of ONE point cloud,
# so each point cloud is processed on its own. The algorithm also returns
# points on the diagonal (needed to compute the persistence diagrams);
# filter_Cech_diagram removes them.
cech_diagrams = []
for point_cloud in point_clouds:
    # Diagram of *this* point cloud, as an (n_points, 3) array.
    filtered = filter_Cech_diagram(CH.fit_transform([point_cloud]))
    cech_diagrams.append(np.asarray(filtered[0]).reshape(-1, 3))
# The filtered diagrams have different lengths. Pad them with (0, 0, 0) rows
# (zero-persistence points on the diagonal) so they stack into a single array
# of shape (n_point_clouds, max_n_points, 3), like diagrams_VR.
max_n_points = max(len(diagram) for diagram in cech_diagrams)
diagrams_CH = np.zeros((len(cech_diagrams), max_n_points, 3))
for idx, diagram in enumerate(cech_diagrams):
    diagrams_CH[idx, :len(diagram)] = diagram
diagrams_CH.shape
(3, 952, 3)
Now let's plot the diagrams for the Vietoris-Rips and the Cech filtrations. We can do it either with the `plot_diagram`
function or with the `plot` method of the `VR` object.
from gtda.plotting import plot_diagram
# Plot the Vietoris-Rips diagram of the first point cloud (index 0, the
# circumference sample).
plot_diagram(diagrams_VR[0])
or
# Same diagram through the transformer's own plotting method; `sample` picks
# which diagram of diagrams_VR to draw.
VR.plot(diagrams_VR, sample=0)
Now let's plot both Cech and Vietoris-Rips persistence diagrams.
import chart_studio.plotly as py
from ipywidgets import widgets
import plotly.graph_objects as go
def plot_VR_and_Cech(ppdd_VR, ppdd_Cech):
    """Show a Vietoris-Rips and a Cech persistence diagram side by side.

    Parameters
    ----------
    ppdd_VR, ppdd_Cech
        Single persistence diagrams accepted by ``plot_diagram``.

    Returns
    -------
    widgets.HBox
        The VR figure on the left and the Cech figure on the right, each
        wrapped in a ``go.FigureWidget``.
    """
    titled_diagrams = (
        (ppdd_VR, "Vietoris-Rips persistence diagram"),
        (ppdd_Cech, "Cech persistence diagram"),
    )
    figures = []
    for diagram, title in titled_diagrams:
        figure = plot_diagram(diagram)
        figure.update_layout(title=title)
        figures.append(figure)
    return widgets.HBox(children=[go.FigureWidget(figure) for figure in figures])
# Index of the first torus sample: classes are stored in blocks of
# n_samples_per_class, with the torus block third (same indexing as `torus`).
point_cloud_torus_id = n_samples_per_class*2
plot_VR_and_Cech(diagrams_VR[point_cloud_torus_id], diagrams_CH[point_cloud_torus_id]) #diagrams from sampled torus
Now let's compute some features from our persistence diagrams.
from gtda.diagrams import PersistenceEntropy, PersistenceImage
PE = PersistenceEntropy()
# Persistence-entropy features for every VR diagram; presumably one value per
# (sample, homology dimension) - TODO confirm against the gtda docs.
features = PE.fit_transform(diagrams_VR)
# NOTE(review): the output recorded for this cell is a 100-value curve that is
# identical to row [0][1] of the silhouette output further down, not an
# entropy value. It looks stale; re-run the cell to confirm.
features[0][1]
array([0. , 0. , 0. , 0. , 0.00758692, 0.02731114, 0.04703537, 0.06675959, 0.08648382, 0.10620804, 0.12593226, 0.14565649, 0.16538071, 0.18510494, 0.20482916, 0.22455338, 0.24427761, 0.26400183, 0.28372606, 0.30345028, 0.3231745 , 0.34289873, 0.36262295, 0.38234718, 0.4020714 , 0.42179562, 0.44151985, 0.46124407, 0.4809683 , 0.50069252, 0.52041674, 0.54014097, 0.55986519, 0.57958942, 0.59931364, 0.61903786, 0.63876209, 0.65848631, 0.67821054, 0.69793476, 0.71765898, 0.73738321, 0.75710743, 0.74068368, 0.72095945, 0.70123523, 0.68151101, 0.66178678, 0.64206256, 0.62233833, 0.60261411, 0.58288989, 0.56316566, 0.54344144, 0.52371721, 0.50399299, 0.48426877, 0.46454454, 0.44482032, 0.42509609, 0.40537187, 0.38564765, 0.36592342, 0.3461992 , 0.32647497, 0.30675075, 0.28702653, 0.2673023 , 0.24757808, 0.22785385, 0.20812963, 0.18840541, 0.16868118, 0.14895696, 0.12923273, 0.10950851, 0.08978429, 0.07006006, 0.05033584, 0.03061162, 0.01088739, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ])
import matplotlib.image as mpimg
PI = PersistenceImage()
# Persistence images of every VR diagram.
features = PI.fit_transform(diagrams_VR)
# Show features[0][1], presumably the image of sample 0 in homology
# dimension 1 (assumes a (sample, dimension, ...) layout - TODO confirm).
plt.imshow(features[0][1])
<matplotlib.image.AxesImage at 0x197a0e66070>
from gtda.diagrams import Silhouette
PS = Silhouette()
# Silhouette curves of the VR diagrams; fit_transform_plot also draws one
# sample as a side effect (TODO confirm which sample is plotted by default).
features = PS.fit_transform_plot(diagrams_VR)
features
array([[[0.00000000e+00, 7.39900115e-03, 1.37935752e-02, 1.79439340e-02, 1.86358909e-02, 1.64152840e-02, 1.25623199e-02, 8.57831872e-03, 5.43301115e-03, 3.15642730e-03, 1.57422868e-03, 5.80019867e-04, 2.54822622e-04, 9.57567790e-05, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.58692007e-03, 2.73111441e-02, 4.70353680e-02, 6.67595920e-02, 8.64838160e-02, 1.06208040e-01, 1.25932264e-01, 1.45656488e-01, 1.65380712e-01, 1.85104936e-01, 2.04829160e-01, 2.24553384e-01, 2.44277608e-01, 2.64001832e-01, 2.83726056e-01, 3.03450280e-01, 3.23174504e-01, 3.42898728e-01, 3.62622952e-01, 3.82347176e-01, 
4.02071400e-01, 4.21795624e-01, 4.41519848e-01, 4.61244072e-01, 4.80968296e-01, 5.00692520e-01, 5.20416744e-01, 5.40140968e-01, 5.59865192e-01, 5.79589416e-01, 5.99313640e-01, 6.19037864e-01, 6.38762088e-01, 6.58486312e-01, 6.78210536e-01, 6.97934760e-01, 7.17658984e-01, 7.37383208e-01, 7.57107432e-01, 7.40683679e-01, 7.20959455e-01, 7.01235231e-01, 6.81511007e-01, 6.61786783e-01, 6.42062559e-01, 6.22338335e-01, 6.02614111e-01, 5.82889887e-01, 5.63165663e-01, 5.43441439e-01, 5.23717215e-01, 5.03992991e-01, 4.84268767e-01, 4.64544543e-01, 4.44820319e-01, 4.25096095e-01, 4.05371871e-01, 3.85647647e-01, 3.65923423e-01, 3.46199199e-01, 3.26474975e-01, 3.06750751e-01, 2.87026527e-01, 2.67302303e-01, 2.47578079e-01, 2.27853855e-01, 2.08129631e-01, 1.88405407e-01, 1.68681183e-01, 1.48956959e-01, 1.29232735e-01, 1.09508511e-01, 8.97842870e-02, 7.00600630e-02, 5.03358390e-02, 3.06116150e-02, 1.08873911e-02, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 3.77615735e-03, 6.07607452e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]])