{ "cells": [ { "cell_type": "markdown", "id": "d3b991b0", "metadata": {}, "source": [ "# TD3: Statistics, Unsupervised, and Supervised Machine Learning on 3D shapes with Topological Data Analysis" ] }, { "cell_type": "markdown", "id": "1791e9c1", "metadata": {}, "source": [ "In this practical session, we will use the various TDA tools presented in class to run data science tasks (inference, clustering, classification) on a data set of 3D shapes. As in the first practical session, we will use [`Gudhi`](https://gudhi.inria.fr/) (see the first practical session for installation instructions). The different sections of this notebook can be run independently (except Section 0, which is mandatory), so feel free to start with the project that sounds the most interesting to you :-)" ] }, { "cell_type": "markdown", "id": "414af03f", "metadata": {}, "source": [ "Note also that if you choose to switch from one section to another, make sure to clear all variables first (and run Section 0 again), since some variable names are shared between sections." ] }, { "cell_type": "code", "execution_count": 1, "id": "c130ab56", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.8.0a0\n" ] } ], "source": [ "import gudhi as gd\n", "print(gd.__version__)" ] }, { "cell_type": "markdown", "id": "f907d3c0", "metadata": {}, "source": [ "Other than that, you are free to use whatever other Python packages you feel comfortable with :-) We make some suggestions below (these dependencies are also required to run our solutions to the exercises). " ] }, { "cell_type": "code", "execution_count": 2, "id": "cf40a838", "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys" ] }, { "cell_type": "markdown", "id": "348293aa", "metadata": {}, "source": [ "We will use three standard Python libraries: `NumPy`, `SciPy` and `Matplotlib`." 
] }, { "cell_type": "code", "execution_count": 3, "id": "bea0919a", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import scipy as sp\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 4, "id": "83a1ac92", "metadata": {}, "outputs": [], "source": [ "%matplotlib notebook" ] }, { "cell_type": "markdown", "id": "c15ff505", "metadata": {}, "source": [ "In order to visualize 3D shapes, we will use [`meshplot`](https://skoch9.github.io/meshplot/tutorial/)." ] }, { "cell_type": "code", "execution_count": 5, "id": "fea43227", "metadata": {}, "outputs": [], "source": [ "import meshplot as mp" ] }, { "cell_type": "markdown", "id": "c389c8fb", "metadata": {}, "source": [ "When computing vectorizations and performing supervised machine learning and deep learning tasks, we will use various modules of [`Scikit-Learn`](https://scikit-learn.org/stable/index.html). " ] }, { "cell_type": "code", "execution_count": 6, "id": "7a9bbaa9", "metadata": {}, "outputs": [], "source": [ "import sklearn.preprocessing as skp\n", "import sklearn.neighbors as skn\n", "import sklearn.model_selection as skm\n", "import sklearn.decomposition as skd\n", "import sklearn.manifold as skf\n", "import sklearn.pipeline as skl\n", "import sklearn.svm as sks\n", "import sklearn.ensemble as ske" ] }, { "cell_type": "markdown", "id": "607645e0", "metadata": {}, "source": [ "# Section 0: Data set manipulation" ] }, { "cell_type": "markdown", "id": "a6c8dd35", "metadata": {}, "source": [ "We are good to go! First things first, we have to download the data set. It can be obtained [here](https://people.cs.umass.edu/~kalo/papers/LabelMeshes/labeledDb.7z). Extract it, and save its path in the `dataset_path` variable." 
] }, { "cell_type": "code", "execution_count": 7, "id": "9ed33357", "metadata": {}, "outputs": [], "source": [ "dataset_path = './3dshapes/'" ] }, { "cell_type": "markdown", "id": "63bf89b6", "metadata": {}, "source": [ "As you can see, the data set is split into several categories (`Airplane`, `Human`, `Teddy`, etc.), each category having its own folder. Inside each folder, some 3D shapes (i.e., 3D triangulations) are provided in [`.off`](https://en.wikipedia.org/wiki/OFF_(file_format)) format, and face (i.e., triangle) labels are provided in text files (extension `.txt`). " ] }, { "cell_type": "markdown", "id": "78cf080e", "metadata": {}, "source": [ "Every data science project begins with some preprocessing ;-) " ] }, { "cell_type": "markdown", "id": "3b52c414", "metadata": {}, "source": [ "Write a function `off2numpy` that reads information from an `.off` file and stores it in two `NumPy` arrays, called `vertices` (type float and shape number_of_vertices x 3---the 3D coordinates of the vertices) and `faces` (type integer and shape number_of_faces x 3---the IDs of the vertices that form each face). Also write a function `get_labels` that stores the face labels of a given 3D shape in a `NumPy` array (type string or integer and shape [number_of_faces])." ] }, { "cell_type": "code", "execution_count": 8, "id": "c2d26e0a", "metadata": {}, "outputs": [], "source": [ "def off2numpy(shape_name):\n", "    with open(shape_name, 'r') as S:\n", "        # Skip the first line (the OFF header).\n", "        S.readline()\n", "        # Second line: numbers of vertices and faces (the third value is ignored).\n", "        num_vertices, num_faces, _ = [int(n) for n in S.readline().split(' ')]\n", "        info = S.readlines()\n", "    vertices = np.array([[float(coord) for coord in l.split(' ')] for l in info[0:num_vertices]])\n", "    # Each face line starts with its vertex count, which we drop.\n", "    faces = np.array([[int(coord) for coord in l.split(' ')[1:]] for l in info[num_vertices:]])\n", "    return vertices, faces" ] }, { "cell_type": "code", "execution_count": 9, "id": "7ce122d7", "metadata": {}, "outputs": [], "source": [ "def get_labels(label_name, num_faces):\n", "    L = np.empty([num_faces], dtype='|S100')\n", "    with open(label_name, 'r') as S:\n", "        info = S.readlines()\n", "    # Lines alternate: a label name, then the 1-based indices of the faces carrying it.\n", "    labels, face_indices = info[0::2], info[1::2]\n", "    for ilab, lab in enumerate(labels):\n", "        indices = [int(f)-1 for f in face_indices[ilab].split(' ')[:-1]]\n", "        L[ np.array(indices) ] = lab[:-1]\n", "    return L" ] }, { "cell_type": "markdown", "id": "1b6db7fa", "metadata": {}, "source": [ "You can now apply your code and use `meshplot` to visualize a given 3D shape, say `61.off` in `Airplane`, and the labels on its faces." 
] }, { "cell_type": "code", "execution_count": 10, "id": "07117f79", "metadata": {}, "outputs": [], "source": [ "vertices, faces = off2numpy(dataset_path + 'Airplane/61.off')\n", "label_faces = get_labels(dataset_path + 'Airplane/61_labels.txt', len(faces))" ] }, { "cell_type": "code", "execution_count": 11, "id": "0c8395e3", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b3aa1b74cd654436a0d9978175f447c4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0636740…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mp.plot(vertices, faces, c=skp.LabelEncoder().fit_transform(label_faces))" ] }, { "cell_type": "markdown", "id": "ae5210fe", "metadata": {}, "source": [ "If `meshplot` does not work, we also provide a fallback with `matplotlib` (it requires converting the face labels into point labels, though)." ] }, { "cell_type": "code", "execution_count": 12, "id": "c449241b", "metadata": {}, "outputs": [], "source": [ "def face2points(vals_faces, faces, num_vertices):\n", "    # Use the dtype of the face values (type(vals_faces) would be the ndarray class, not a dtype).\n", "    vals_points = np.empty([num_vertices], dtype=np.asarray(vals_faces).dtype)\n", "    # Each vertex takes the value of the last face that contains it.\n", "    for iface, face in enumerate(faces):\n", "        vals_points[face] = vals_faces[iface]\n", "    return vals_points" ] }, { "cell_type": "code", "execution_count": 13, "id": "82718067", "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "/* global mpl */\n", "window.mpl = {};\n", "\n", "mpl.get_websocket_type = function () {\n", " if (typeof WebSocket !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof MozWebSocket !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert(\n", " 'Your browser does not have WebSocket support. 
' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.'\n", " );\n", " }\n", "};\n", "\n", "mpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = this.ws.binaryType !== undefined;\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById('mpl-warnings');\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent =\n", " 'This browser does not support binary websocket messages. ' +\n", " 'Performance may be slow.';\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = document.createElement('div');\n", " this.root.setAttribute('style', 'display: inline-block');\n", " this._root_extra_style(this.root);\n", "\n", " parent_element.appendChild(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message('supports_binary', { value: fig.supports_binary });\n", " fig.send_message('send_image_mode', {});\n", " if (fig.ratio !== 1) {\n", " fig.send_message('set_device_pixel_ratio', {\n", " device_pixel_ratio: fig.ratio,\n", " });\n", " }\n", " fig.send_message('refresh', {});\n", " };\n", "\n", " this.imageObj.onload = function () {\n", " if (fig.image_mode === 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " 
fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function () {\n", " fig.ws.close();\n", " };\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "};\n", "\n", "mpl.figure.prototype._init_header = function () {\n", " var titlebar = document.createElement('div');\n", " titlebar.classList =\n", " 'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n", " var titletext = document.createElement('div');\n", " titletext.classList = 'ui-dialog-title';\n", " titletext.setAttribute(\n", " 'style',\n", " 'width: 100%; text-align: center; padding: 3px;'\n", " );\n", " titlebar.appendChild(titletext);\n", " this.root.appendChild(titlebar);\n", " this.header = titletext;\n", "};\n", "\n", "mpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n", "\n", "mpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n", "\n", "mpl.figure.prototype._init_canvas = function () {\n", " var fig = this;\n", "\n", " var canvas_div = (this.canvas_div = document.createElement('div'));\n", " canvas_div.setAttribute(\n", " 'style',\n", " 'border: 1px solid #ddd;' +\n", " 'box-sizing: content-box;' +\n", " 'clear: both;' +\n", " 'min-height: 1px;' +\n", " 'min-width: 1px;' +\n", " 'outline: 0;' +\n", " 'overflow: hidden;' +\n", " 'position: relative;' +\n", " 'resize: both;'\n", " );\n", "\n", " function on_keyboard_event_closure(name) {\n", " return function (event) {\n", " return fig.key_event(event, name);\n", " };\n", " }\n", "\n", " canvas_div.addEventListener(\n", " 'keydown',\n", " on_keyboard_event_closure('key_press')\n", " );\n", " canvas_div.addEventListener(\n", " 'keyup',\n", " on_keyboard_event_closure('key_release')\n", " );\n", "\n", " this._canvas_extra_style(canvas_div);\n", " this.root.appendChild(canvas_div);\n", "\n", " var canvas = (this.canvas = 
document.createElement('canvas'));\n", " canvas.classList.add('mpl-canvas');\n", " canvas.setAttribute('style', 'box-sizing: content-box;');\n", "\n", " this.context = canvas.getContext('2d');\n", "\n", " var backingStore =\n", " this.context.backingStorePixelRatio ||\n", " this.context.webkitBackingStorePixelRatio ||\n", " this.context.mozBackingStorePixelRatio ||\n", " this.context.msBackingStorePixelRatio ||\n", " this.context.oBackingStorePixelRatio ||\n", " this.context.backingStorePixelRatio ||\n", " 1;\n", "\n", " this.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n", " 'canvas'\n", " ));\n", " rubberband_canvas.setAttribute(\n", " 'style',\n", " 'box-sizing: content-box; position: absolute; left: 0; top: 0; z-index: 1;'\n", " );\n", "\n", " // Apply a ponyfill if ResizeObserver is not implemented by browser.\n", " if (this.ResizeObserver === undefined) {\n", " if (window.ResizeObserver !== undefined) {\n", " this.ResizeObserver = window.ResizeObserver;\n", " } else {\n", " var obs = _JSXTOOLS_RESIZE_OBSERVER({});\n", " this.ResizeObserver = obs.ResizeObserver;\n", " }\n", " }\n", "\n", " this.resizeObserverInstance = new this.ResizeObserver(function (entries) {\n", " var nentries = entries.length;\n", " for (var i = 0; i < nentries; i++) {\n", " var entry = entries[i];\n", " var width, height;\n", " if (entry.contentBoxSize) {\n", " if (entry.contentBoxSize instanceof Array) {\n", " // Chrome 84 implements new version of spec.\n", " width = entry.contentBoxSize[0].inlineSize;\n", " height = entry.contentBoxSize[0].blockSize;\n", " } else {\n", " // Firefox implements old version of spec.\n", " width = entry.contentBoxSize.inlineSize;\n", " height = entry.contentBoxSize.blockSize;\n", " }\n", " } else {\n", " // Chrome <84 implements even older version of spec.\n", " width = entry.contentRect.width;\n", " height = entry.contentRect.height;\n", " }\n", "\n", " // 
Keep the size of the canvas and rubber band canvas in sync with\n", " // the canvas container.\n", " if (entry.devicePixelContentBoxSize) {\n", " // Chrome 84 implements new version of spec.\n", " canvas.setAttribute(\n", " 'width',\n", " entry.devicePixelContentBoxSize[0].inlineSize\n", " );\n", " canvas.setAttribute(\n", " 'height',\n", " entry.devicePixelContentBoxSize[0].blockSize\n", " );\n", " } else {\n", " canvas.setAttribute('width', width * fig.ratio);\n", " canvas.setAttribute('height', height * fig.ratio);\n", " }\n", " canvas.setAttribute(\n", " 'style',\n", " 'width: ' + width + 'px; height: ' + height + 'px;'\n", " );\n", "\n", " rubberband_canvas.setAttribute('width', width);\n", " rubberband_canvas.setAttribute('height', height);\n", "\n", " // And update the size in Python. We ignore the initial 0/0 size\n", " // that occurs as the element is placed into the DOM, which should\n", " // otherwise not happen due to the minimum size styling.\n", " if (fig.ws.readyState == 1 && width != 0 && height != 0) {\n", " fig.request_resize(width, height);\n", " }\n", " }\n", " });\n", " this.resizeObserverInstance.observe(canvas_div);\n", "\n", " function on_mouse_event_closure(name) {\n", " return function (event) {\n", " return fig.mouse_event(event, name);\n", " };\n", " }\n", "\n", " rubberband_canvas.addEventListener(\n", " 'mousedown',\n", " on_mouse_event_closure('button_press')\n", " );\n", " rubberband_canvas.addEventListener(\n", " 'mouseup',\n", " on_mouse_event_closure('button_release')\n", " );\n", " rubberband_canvas.addEventListener(\n", " 'dblclick',\n", " on_mouse_event_closure('dblclick')\n", " );\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband_canvas.addEventListener(\n", " 'mousemove',\n", " on_mouse_event_closure('motion_notify')\n", " );\n", "\n", " rubberband_canvas.addEventListener(\n", " 'mouseenter',\n", " on_mouse_event_closure('figure_enter')\n", " );\n", " rubberband_canvas.addEventListener(\n", " 
'mouseleave',\n", " on_mouse_event_closure('figure_leave')\n", " );\n", "\n", " canvas_div.addEventListener('wheel', function (event) {\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " on_mouse_event_closure('scroll')(event);\n", " });\n", "\n", " canvas_div.appendChild(canvas);\n", " canvas_div.appendChild(rubberband_canvas);\n", "\n", " this.rubberband_context = rubberband_canvas.getContext('2d');\n", " this.rubberband_context.strokeStyle = '#000000';\n", "\n", " this._resize_canvas = function (width, height, forward) {\n", " if (forward) {\n", " canvas_div.style.width = width + 'px';\n", " canvas_div.style.height = height + 'px';\n", " }\n", " };\n", "\n", " // Disable right mouse context menu.\n", " this.rubberband_canvas.addEventListener('contextmenu', function (_e) {\n", " event.preventDefault();\n", " return false;\n", " });\n", "\n", " function set_focus() {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "};\n", "\n", "mpl.figure.prototype._init_toolbar = function () {\n", " var fig = this;\n", "\n", " var toolbar = document.createElement('div');\n", " toolbar.classList = 'mpl-toolbar';\n", " this.root.appendChild(toolbar);\n", "\n", " function on_click_closure(name) {\n", " return function (_event) {\n", " return fig.toolbar_button_onclick(name);\n", " };\n", " }\n", "\n", " function on_mouseover_closure(tooltip) {\n", " return function (event) {\n", " if (!event.currentTarget.disabled) {\n", " return fig.toolbar_button_onmouseover(tooltip);\n", " }\n", " };\n", " }\n", "\n", " fig.buttons = {};\n", " var buttonGroup = document.createElement('div');\n", " buttonGroup.classList = 'mpl-button-group';\n", " for (var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = 
mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " /* Instead of a spacer, we start a new button group. */\n", " if (buttonGroup.hasChildNodes()) {\n", " toolbar.appendChild(buttonGroup);\n", " }\n", " buttonGroup = document.createElement('div');\n", " buttonGroup.classList = 'mpl-button-group';\n", " continue;\n", " }\n", "\n", " var button = (fig.buttons[name] = document.createElement('button'));\n", " button.classList = 'mpl-widget';\n", " button.setAttribute('role', 'button');\n", " button.setAttribute('aria-disabled', 'false');\n", " button.addEventListener('click', on_click_closure(method_name));\n", " button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n", "\n", " var icon_img = document.createElement('img');\n", " icon_img.src = '_images/' + image + '.png';\n", " icon_img.srcset = '_images/' + image + '_large.png 2x';\n", " icon_img.alt = tooltip;\n", " button.appendChild(icon_img);\n", "\n", " buttonGroup.appendChild(button);\n", " }\n", "\n", " if (buttonGroup.hasChildNodes()) {\n", " toolbar.appendChild(buttonGroup);\n", " }\n", "\n", " var fmt_picker = document.createElement('select');\n", " fmt_picker.classList = 'mpl-widget';\n", " toolbar.appendChild(fmt_picker);\n", " this.format_dropdown = fmt_picker;\n", "\n", " for (var ind in mpl.extensions) {\n", " var fmt = mpl.extensions[ind];\n", " var option = document.createElement('option');\n", " option.selected = fmt === mpl.default_extension;\n", " option.innerHTML = fmt;\n", " fmt_picker.appendChild(option);\n", " }\n", "\n", " var status_bar = document.createElement('span');\n", " status_bar.classList = 'mpl-message';\n", " toolbar.appendChild(status_bar);\n", " this.message = status_bar;\n", "};\n", "\n", "mpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n", " // Request matplotlib to resize the figure. 
Matplotlib will then trigger a resize in the client,\n", " // which will in turn request a refresh of the image.\n", " this.send_message('resize', { width: x_pixels, height: y_pixels });\n", "};\n", "\n", "mpl.figure.prototype.send_message = function (type, properties) {\n", " properties['type'] = type;\n", " properties['figure_id'] = this.id;\n", " this.ws.send(JSON.stringify(properties));\n", "};\n", "\n", "mpl.figure.prototype.send_draw_message = function () {\n", " if (!this.waiting) {\n", " this.waiting = true;\n", " this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n", " }\n", "};\n", "\n", "mpl.figure.prototype.handle_save = function (fig, _msg) {\n", " var format_dropdown = fig.format_dropdown;\n", " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n", " fig.ondownload(fig, format);\n", "};\n", "\n", "mpl.figure.prototype.handle_resize = function (fig, msg) {\n", " var size = msg['size'];\n", " if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n", " fig._resize_canvas(size[0], size[1], msg['forward']);\n", " fig.send_message('refresh', {});\n", " }\n", "};\n", "\n", "mpl.figure.prototype.handle_rubberband = function (fig, msg) {\n", " var x0 = msg['x0'] / fig.ratio;\n", " var y0 = (fig.canvas.height - msg['y0']) / fig.ratio;\n", " var x1 = msg['x1'] / fig.ratio;\n", " var y1 = (fig.canvas.height - msg['y1']) / fig.ratio;\n", " x0 = Math.floor(x0) + 0.5;\n", " y0 = Math.floor(y0) + 0.5;\n", " x1 = Math.floor(x1) + 0.5;\n", " y1 = Math.floor(y1) + 0.5;\n", " var min_x = Math.min(x0, x1);\n", " var min_y = Math.min(y0, y1);\n", " var width = Math.abs(x1 - x0);\n", " var height = Math.abs(y1 - y0);\n", "\n", " fig.rubberband_context.clearRect(\n", " 0,\n", " 0,\n", " fig.canvas.width / fig.ratio,\n", " fig.canvas.height / fig.ratio\n", " );\n", "\n", " fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n", "};\n", "\n", "mpl.figure.prototype.handle_figure_label = function (fig, msg) 
{\n", " // Updates the figure title.\n", " fig.header.textContent = msg['label'];\n", "};\n", "\n", "mpl.figure.prototype.handle_cursor = function (fig, msg) {\n", " fig.rubberband_canvas.style.cursor = msg['cursor'];\n", "};\n", "\n", "mpl.figure.prototype.handle_message = function (fig, msg) {\n", " fig.message.textContent = msg['message'];\n", "};\n", "\n", "mpl.figure.prototype.handle_draw = function (fig, _msg) {\n", " // Request the server to send over a new figure.\n", " fig.send_draw_message();\n", "};\n", "\n", "mpl.figure.prototype.handle_image_mode = function (fig, msg) {\n", " fig.image_mode = msg['mode'];\n", "};\n", "\n", "mpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n", " for (var key in msg) {\n", " if (!(key in fig.buttons)) {\n", " continue;\n", " }\n", " fig.buttons[key].disabled = !msg[key];\n", " fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n", " }\n", "};\n", "\n", "mpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n", " if (msg['mode'] === 'PAN') {\n", " fig.buttons['Pan'].classList.add('active');\n", " fig.buttons['Zoom'].classList.remove('active');\n", " } else if (msg['mode'] === 'ZOOM') {\n", " fig.buttons['Pan'].classList.remove('active');\n", " fig.buttons['Zoom'].classList.add('active');\n", " } else {\n", " fig.buttons['Pan'].classList.remove('active');\n", " fig.buttons['Zoom'].classList.remove('active');\n", " }\n", "};\n", "\n", "mpl.figure.prototype.updated_canvas_event = function () {\n", " // Called whenever the canvas gets updated.\n", " this.send_message('ack', {});\n", "};\n", "\n", "// A function to construct a web socket function for onmessage handling.\n", "// Called in the figure constructor.\n", "mpl.figure.prototype._make_on_message_function = function (fig) {\n", " return function socket_on_message(evt) {\n", " if (evt.data instanceof Blob) {\n", " var img = evt.data;\n", " if (img.type !== 'image/png') {\n", " /* FIXME: We get \"Resource interpreted as Image but\n", 
" * transferred with MIME type text/plain:\" errors on\n", " * Chrome. But how to set the MIME type? It doesn't seem\n", " * to be part of the websocket stream */\n", " img.type = 'image/png';\n", " }\n", "\n", " /* Free the memory for the previous frames */\n", " if (fig.imageObj.src) {\n", " (window.URL || window.webkitURL).revokeObjectURL(\n", " fig.imageObj.src\n", " );\n", " }\n", "\n", " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n", " img\n", " );\n", " fig.updated_canvas_event();\n", " fig.waiting = false;\n", " return;\n", " } else if (\n", " typeof evt.data === 'string' &&\n", " evt.data.slice(0, 21) === 'data:image/png;base64'\n", " ) {\n", " fig.imageObj.src = evt.data;\n", " fig.updated_canvas_event();\n", " fig.waiting = false;\n", " return;\n", " }\n", "\n", " var msg = JSON.parse(evt.data);\n", " var msg_type = msg['type'];\n", "\n", " // Call the \"handle_{type}\" callback, which takes\n", " // the figure and JSON message as its only arguments.\n", " try {\n", " var callback = fig['handle_' + msg_type];\n", " } catch (e) {\n", " console.log(\n", " \"No handler for the '\" + msg_type + \"' message type: \",\n", " msg\n", " );\n", " return;\n", " }\n", "\n", " if (callback) {\n", " try {\n", " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n", " callback(fig, msg);\n", " } catch (e) {\n", " console.log(\n", " \"Exception inside the 'handler_\" + msg_type + \"' callback:\",\n", " e,\n", " e.stack,\n", " msg\n", " );\n", " }\n", " }\n", " };\n", "};\n", "\n", "// from https://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n", "mpl.findpos = function (e) {\n", " //this section is from http://www.quirksmode.org/js/events_properties.html\n", " var targ;\n", " if (!e) {\n", " e = window.event;\n", " }\n", " if (e.target) {\n", " targ = e.target;\n", " } else if (e.srcElement) {\n", " targ = e.srcElement;\n", " }\n", " if (targ.nodeType === 3) {\n", " // defeat Safari bug\n", " targ = 
targ.parentNode;\n", " }\n", "\n", " // pageX,Y are the mouse positions relative to the document\n", " var boundingRect = targ.getBoundingClientRect();\n", " var x = e.pageX - (boundingRect.left + document.body.scrollLeft);\n", " var y = e.pageY - (boundingRect.top + document.body.scrollTop);\n", "\n", " return { x: x, y: y };\n", "};\n", "\n", "/*\n", " * return a copy of an object with only non-object keys\n", " * we need this to avoid circular references\n", " * https://stackoverflow.com/a/24161582/3208463\n", " */\n", "function simpleKeys(original) {\n", " return Object.keys(original).reduce(function (obj, key) {\n", " if (typeof original[key] !== 'object') {\n", " obj[key] = original[key];\n", " }\n", " return obj;\n", " }, {});\n", "}\n", "\n", "mpl.figure.prototype.mouse_event = function (event, name) {\n", " var canvas_pos = mpl.findpos(event);\n", "\n", " if (name === 'button_press') {\n", " this.canvas.focus();\n", " this.canvas_div.focus();\n", " }\n", "\n", " var x = canvas_pos.x * this.ratio;\n", " var y = canvas_pos.y * this.ratio;\n", "\n", " this.send_message(name, {\n", " x: x,\n", " y: y,\n", " button: event.button,\n", " step: event.step,\n", " guiEvent: simpleKeys(event),\n", " });\n", "\n", " /* This prevents the web browser from automatically changing to\n", " * the text insertion cursor when the button is pressed. 
We want\n", " * to control all of the cursor setting manually through the\n", " * 'cursor' event from matplotlib */\n", " event.preventDefault();\n", " return false;\n", "};\n", "\n", "mpl.figure.prototype._key_event_extra = function (_event, _name) {\n", " // Handle any extra behaviour associated with a key event\n", "};\n", "\n", "mpl.figure.prototype.key_event = function (event, name) {\n", " // Prevent repeat events\n", " if (name === 'key_press') {\n", " if (event.key === this._key) {\n", " return;\n", " } else {\n", " this._key = event.key;\n", " }\n", " }\n", " if (name === 'key_release') {\n", " this._key = null;\n", " }\n", "\n", " var value = '';\n", " if (event.ctrlKey && event.key !== 'Control') {\n", " value += 'ctrl+';\n", " }\n", " else if (event.altKey && event.key !== 'Alt') {\n", " value += 'alt+';\n", " }\n", " else if (event.shiftKey && event.key !== 'Shift') {\n", " value += 'shift+';\n", " }\n", "\n", " value += 'k' + event.key;\n", "\n", " this._key_event_extra(event, name);\n", "\n", " this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n", " return false;\n", "};\n", "\n", "mpl.figure.prototype.toolbar_button_onclick = function (name) {\n", " if (name === 'download') {\n", " this.handle_save(this, null);\n", " } else {\n", " this.send_message('toolbar_button', { name: name });\n", " }\n", "};\n", "\n", "mpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n", " this.message.textContent = tooltip;\n", "};\n", "\n", "///////////////// REMAINING CONTENT GENERATED BY embed_js.py /////////////////\n", "// prettier-ignore\n", "var _JSXTOOLS_RESIZE_OBSERVER=function(A){var t,i=new WeakMap,n=new WeakMap,a=new WeakMap,r=new WeakMap,o=new Set;function s(e){if(!(this instanceof s))throw new TypeError(\"Constructor requires 'new' operator\");i.set(this,e)}function h(){throw new TypeError(\"Function is not a constructor\")}function c(e,t,i,n){e=0 in arguments?Number(arguments[0]):0,t=1 in 
arguments?Number(arguments[1]):0,i=2 in arguments?Number(arguments[2]):0,n=3 in arguments?Number(arguments[3]):0,this.right=(this.x=this.left=e)+(this.width=i),this.bottom=(this.y=this.top=t)+(this.height=n),Object.freeze(this)}function d(){t=requestAnimationFrame(d);var s=new WeakMap,p=new Set;o.forEach((function(t){r.get(t).forEach((function(i){var r=t instanceof window.SVGElement,o=a.get(t),d=r?0:parseFloat(o.paddingTop),f=r?0:parseFloat(o.paddingRight),l=r?0:parseFloat(o.paddingBottom),u=r?0:parseFloat(o.paddingLeft),g=r?0:parseFloat(o.borderTopWidth),m=r?0:parseFloat(o.borderRightWidth),w=r?0:parseFloat(o.borderBottomWidth),b=u+f,F=d+l,v=(r?0:parseFloat(o.borderLeftWidth))+m,W=g+w,y=r?0:t.offsetHeight-W-t.clientHeight,E=r?0:t.offsetWidth-v-t.clientWidth,R=b+v,z=F+W,M=r?t.width:parseFloat(o.width)-R-E,O=r?t.height:parseFloat(o.height)-z-y;if(n.has(t)){var k=n.get(t);if(k[0]===M&&k[1]===O)return}n.set(t,[M,O]);var S=Object.create(h.prototype);S.target=t,S.contentRect=new c(u,d,M,O),s.has(i)||(s.set(i,[]),p.add(i)),s.get(i).push(S)}))})),p.forEach((function(e){i.get(e).call(e,s.get(e),e)}))}return s.prototype.observe=function(i){if(i instanceof window.Element){r.has(i)||(r.set(i,new Set),o.add(i),a.set(i,window.getComputedStyle(i)));var n=r.get(i);n.has(this)||n.add(this),cancelAnimationFrame(t),t=requestAnimationFrame(d)}},s.prototype.unobserve=function(i){if(i instanceof window.Element&&r.has(i)){var n=r.get(i);n.has(this)&&(n.delete(this),n.size||(r.delete(i),o.delete(i))),n.size||r.delete(i),o.size||cancelAnimationFrame(t)}},A.DOMRectReadOnly=c,A.ResizeObserver=s,A.ResizeObserverEntry=h,A}; // eslint-disable-line\n", "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, Right button 
zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n", "\n", "mpl.extensions = [\"eps\", \"jpeg\", \"pgf\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n", "\n", "mpl.default_extension = \"png\";/* global mpl */\n", "\n", "var comm_websocket_adapter = function (comm) {\n", " // Create a \"websocket\"-like object which calls the given IPython comm\n", " // object with the appropriate methods. Currently this is a non binary\n", " // socket, so there is still some room for performance tuning.\n", " var ws = {};\n", "\n", " ws.binaryType = comm.kernel.ws.binaryType;\n", " ws.readyState = comm.kernel.ws.readyState;\n", " function updateReadyState(_event) {\n", " if (comm.kernel.ws) {\n", " ws.readyState = comm.kernel.ws.readyState;\n", " } else {\n", " ws.readyState = 3; // Closed state.\n", " }\n", " }\n", " comm.kernel.ws.addEventListener('open', updateReadyState);\n", " comm.kernel.ws.addEventListener('close', updateReadyState);\n", " comm.kernel.ws.addEventListener('error', updateReadyState);\n", "\n", " ws.close = function () {\n", " comm.close();\n", " };\n", " ws.send = function (m) {\n", " //console.log('sending', m);\n", " comm.send(m);\n", " };\n", " // Register the callback with on_msg.\n", " comm.on_msg(function (msg) {\n", " //console.log('receiving', msg['content']['data'], msg);\n", " var data = msg['content']['data'];\n", " if (data['blob'] !== undefined) {\n", " data = {\n", " data: new Blob(msg['buffers'], { type: data['blob'] }),\n", " };\n", " }\n", " // Pass the mpl event to the overridden (by mpl) onmessage function.\n", " ws.onmessage(data);\n", " });\n", " return ws;\n", "};\n", "\n", "mpl.mpl_figure_comm = function (comm, msg) {\n", " // This is the function which gets called when the mpl process\n", " 
// starts-up an IPython Comm through the \"matplotlib\" channel.\n", "\n", " var id = msg.content.data.id;\n", " // Get hold of the div created by the display call when the Comm\n", " // socket was opened in Python.\n", " var element = document.getElementById(id);\n", " var ws_proxy = comm_websocket_adapter(comm);\n", "\n", " function ondownload(figure, _format) {\n", " window.open(figure.canvas.toDataURL());\n", " }\n", "\n", " var fig = new mpl.figure(id, ws_proxy, ondownload, element);\n", "\n", " // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n", " // web socket which is closed, not our websocket->open comm proxy.\n", " ws_proxy.onopen();\n", "\n", " fig.parent_element = element;\n", " fig.cell_info = mpl.find_output_cell(\"
\");\n", " if (!fig.cell_info) {\n", " console.error('Failed to find cell for figure', id, fig);\n", " return;\n", " }\n", " fig.cell_info[0].output_area.element.on(\n", " 'cleared',\n", " { fig: fig },\n", " fig._remove_fig_handler\n", " );\n", "};\n", "\n", "mpl.figure.prototype.handle_close = function (fig, msg) {\n", " var width = fig.canvas.width / fig.ratio;\n", " fig.cell_info[0].output_area.element.off(\n", " 'cleared',\n", " fig._remove_fig_handler\n", " );\n", " fig.resizeObserverInstance.unobserve(fig.canvas_div);\n", "\n", " // Update the output cell to use the data from the current canvas.\n", " fig.push_to_output();\n", " var dataURL = fig.canvas.toDataURL();\n", " // Re-enable the keyboard manager in IPython - without this line, in FF,\n", " // the notebook keyboard shortcuts fail.\n", " IPython.keyboard_manager.enable();\n", " fig.parent_element.innerHTML =\n", " '';\n", " fig.close_ws(fig, msg);\n", "};\n", "\n", "mpl.figure.prototype.close_ws = function (fig, msg) {\n", " fig.send_message('closing', msg);\n", " // fig.ws.close()\n", "};\n", "\n", "mpl.figure.prototype.push_to_output = function (_remove_interactive) {\n", " // Turn the data on the canvas into data in the output cell.\n", " var width = this.canvas.width / this.ratio;\n", " var dataURL = this.canvas.toDataURL();\n", " this.cell_info[1]['text/html'] =\n", " '';\n", "};\n", "\n", "mpl.figure.prototype.updated_canvas_event = function () {\n", " // Tell IPython that the notebook contents must change.\n", " IPython.notebook.set_dirty(true);\n", " this.send_message('ack', {});\n", " var fig = this;\n", " // Wait a second, then push the new image to the DOM so\n", " // that it is saved nicely (might be nice to debounce this).\n", " setTimeout(function () {\n", " fig.push_to_output();\n", " }, 1000);\n", "};\n", "\n", "mpl.figure.prototype._init_toolbar = function () {\n", " var fig = this;\n", "\n", " var toolbar = document.createElement('div');\n", " toolbar.classList = 'btn-toolbar';\n", 
" this.root.appendChild(toolbar);\n", "\n", " function on_click_closure(name) {\n", " return function (_event) {\n", " return fig.toolbar_button_onclick(name);\n", " };\n", " }\n", "\n", " function on_mouseover_closure(tooltip) {\n", " return function (event) {\n", " if (!event.currentTarget.disabled) {\n", " return fig.toolbar_button_onmouseover(tooltip);\n", " }\n", " };\n", " }\n", "\n", " fig.buttons = {};\n", " var buttonGroup = document.createElement('div');\n", " buttonGroup.classList = 'btn-group';\n", " var button;\n", " for (var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " /* Instead of a spacer, we start a new button group. */\n", " if (buttonGroup.hasChildNodes()) {\n", " toolbar.appendChild(buttonGroup);\n", " }\n", " buttonGroup = document.createElement('div');\n", " buttonGroup.classList = 'btn-group';\n", " continue;\n", " }\n", "\n", " button = fig.buttons[name] = document.createElement('button');\n", " button.classList = 'btn btn-default';\n", " button.href = '#';\n", " button.title = name;\n", " button.innerHTML = '';\n", " button.addEventListener('click', on_click_closure(method_name));\n", " button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n", " buttonGroup.appendChild(button);\n", " }\n", "\n", " if (buttonGroup.hasChildNodes()) {\n", " toolbar.appendChild(buttonGroup);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = document.createElement('span');\n", " status_bar.classList = 'mpl-message pull-right';\n", " toolbar.appendChild(status_bar);\n", " this.message = status_bar;\n", "\n", " // Add the close button to the window.\n", " var buttongrp = document.createElement('div');\n", " buttongrp.classList = 'btn-group inline pull-right';\n", " button = 
document.createElement('button');\n", " button.classList = 'btn btn-mini btn-primary';\n", " button.href = '#';\n", " button.title = 'Stop Interaction';\n", " button.innerHTML = '';\n", " button.addEventListener('click', function (_evt) {\n", " fig.handle_close(fig, {});\n", " });\n", " button.addEventListener(\n", " 'mouseover',\n", " on_mouseover_closure('Stop Interaction')\n", " );\n", " buttongrp.appendChild(button);\n", " var titlebar = this.root.querySelector('.ui-dialog-titlebar');\n", " titlebar.insertBefore(buttongrp, titlebar.firstChild);\n", "};\n", "\n", "mpl.figure.prototype._remove_fig_handler = function (event) {\n", " var fig = event.data.fig;\n", " if (event.target !== this) {\n", " // Ignore bubbled events from children.\n", " return;\n", " }\n", " fig.close_ws(fig, {});\n", "};\n", "\n", "mpl.figure.prototype._root_extra_style = function (el) {\n", " el.style.boxSizing = 'content-box'; // override notebook setting of border-box.\n", "};\n", "\n", "mpl.figure.prototype._canvas_extra_style = function (el) {\n", " // this is important to make the div 'focusable\n", " el.setAttribute('tabindex', 0);\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " } else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "};\n", "\n", "mpl.figure.prototype._key_event_extra = function (event, _name) {\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which === 13) {\n", " this.canvas_div.blur();\n", " // select the cell after this one\n", " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", " IPython.notebook.select(index + 1);\n", " }\n", "};\n", "\n", "mpl.figure.prototype.handle_save = function (fig, _msg) {\n", " fig.ondownload(fig, null);\n", "};\n", "\n", 
"mpl.find_output_cell = function (html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i = 0; i < ncells; i++) {\n", " var cell = cells[i];\n", " if (cell.cell_type === 'code') {\n", " for (var j = 0; j < cell.output_area.outputs.length; j++) {\n", " var data = cell.output_area.outputs[j];\n", " if (data.data) {\n", " // IPython >= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] === html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "};\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel !== null) {\n", " IPython.notebook.kernel.comm_manager.register_target(\n", " 'matplotlib',\n", " mpl.mpl_figure_comm\n", " );\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = plt.figure()\n", "ax = fig.add_subplot(projection='3d')\n", "ax.scatter(vertices[:,0], vertices[:,1], vertices[:,2], s=1, \n", " c=skp.LabelEncoder().fit_transform(\n", " face2points(label_faces, faces, len(vertices))))\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "f346c997", "metadata": {}, "source": [ "# Section 1: 3D shape statistics with persistence diagrams" ] }, { "cell_type": "markdown", "id": "45fc6b6c", "metadata": {}, "source": [ "In this section, our goal is to compute confidence regions associated to the 
persistence diagram of a given 3D shape. We will study such regions for both the persistence diagram and one of its representations, the persistence landscape. " ] }, { "cell_type": "markdown", "id": "9bcb29ff", "metadata": {}, "source": [ "Let's first pick a 3D shape. For instance, use `Hand/181.off` (or any other one you would like to try)." ] }, { "cell_type": "code", "execution_count": 14, "id": "1f6b197f", "metadata": {}, "outputs": [], "source": [ "vertices, faces = off2numpy('3dshapes/Vase/361.off')" ] }, { "cell_type": "code", "execution_count": 15, "id": "c7fc3e05", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e5b92158dd704894b5be39aceea93703", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0170675…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mp.plot(vertices, faces, c=vertices[:,1])" ] }, { "cell_type": "markdown", "id": "a14bf6c8", "metadata": {}, "source": [ "The first standard way of obtaining confidence regions for (geometric) persistence diagrams is through the stability theorem (see class):\n", "\n", "$$\\mathbb{P}(d_b(D_{\\rm Rips}(X),D_{\\rm Rips}(\\hat X_n)) \\geq \\delta)\\leq \\mathbb{P}(d_H(X,\\hat X_n)\\geq \\delta/2),$$\n", "$$\\mathbb{P}(d_b(D_{\\rm Cech}(X),D_{\\rm Cech}(\\hat X_n)) \\geq \\delta)\\leq \\mathbb{P}(d_H(X,\\hat X_n)\\geq \\delta),$$\n", "\n", "where $d_H(\\cdot,\\cdot)$ is the Hausdorff distance, defined, for any two compact spaces $X,Y\\subset \\mathbb{R}^d$, as \n", "\n", "$$d_H(X,Y)={\\rm max}\\{{\\rm max}_{x\\in X}{\\rm min}_{y\\in Y}\\|x-y\\|, {\\rm max}_{y\\in Y}{\\rm min}_{x\\in X}\\|y-x\\|\\}.$$\n", "\n", "Hence, it suffices to estimate $\\mathbb{P}(d_H(X,\\hat X_n)\\geq \\delta)$ in order to derive confidence 
regions for persistence diagrams. There exists an upper bound for this probability when $\\hat X_n$ is drawn from an $(a,b)$-standard probability measure; however, this bound depends on $a$ and $b$. In the following, we will instead use the subsampling method, which estimates this probability solely by subsampling $\\hat X_n$ with $s(n) =o\\left(\\frac{n}{{\\rm log}(n)}\\right)$ points and computing $d_H(\\hat X_n, \\hat X_{s(n)})$. The exact procedure is described in Section 4.1 of [this article](https://doi.org/10.1214/14-AOS1252)." ] }, { "cell_type": "markdown", "id": "efa577d9", "metadata": {}, "source": [ "Write a function `hausdorff_distance` that computes the Hausdorff distance between the vertices of our 3D shape and a subset of these vertices." ] }, { "cell_type": "code", "execution_count": null, "id": "791f3ce7", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "37abb9ad", "metadata": {}, "source": [ "Now, write a function `hausdorff_interval` that computes this Hausdorff distance many times and uses the corresponding distribution of Hausdorff distances in order to output the bottleneck distance value associated to a given confidence level (by computing the quantile---corresponding to this confidence level---of the distribution)." ] }, { "cell_type": "code", "execution_count": null, "id": "98f767d3", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "22241090", "metadata": {}, "source": [ "Apply your code to obtain a bottleneck distance associated to, say, 90% confidence." ] }, { "cell_type": "code", "execution_count": null, "id": "62029e8c", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "8b8b2a54", "metadata": {}, "source": [ "All right, now let's see which points of the persistence diagram we are going to label as non-significant and discard. Compute the Rips and Alpha persistence diagrams of the points. 
" ] }, { "cell_type": "code", "execution_count": null, "id": "4d3f5faa", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "d01d88d4", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "50628fa2", "metadata": {}, "source": [ "Now, visualize the persistence diagrams with a band whose width is the previously computed bottleneck distance multiplied by 2 (for the Alpha filtration) and by 4 (for the Rips filtration)." ] }, { "cell_type": "code", "execution_count": null, "id": "aec230fc", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "c0ec2234", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "c8b3253a", "metadata": {}, "source": [ "Are you discarding many points? If yes, this could be because the confidence region is computed only from the stability property of persistence diagrams: subsampling the Hausdorff distance can sometimes be quite conservative. It would be more efficient to bootstrap the persistence diagrams themselves---this is the approach advocated in Section 6 of [this article](https://www.jmlr.org/papers/volume18/15-484/15-484.pdf). However, this method has only been proven valid for persistence diagrams obtained through the sublevel sets of kernel density estimators... But let's try it anyway! ;-)" ] }, { "cell_type": "markdown", "id": "bda2807c", "metadata": {}, "source": [ "As before, write `bottleneck_distance_bootstrap` and `bottleneck_interval` functions that compute the bottleneck distances between our current persistence diagram (in homology dimension 1) and the persistence diagrams of many bootstrap iterates." 
] }, { "cell_type": "code", "execution_count": null, "id": "dc228b99", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "20596388", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "b8f21a77", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "ee9093ae", "metadata": {}, "source": [ "Compute the bottleneck distance associated to a confidence level and visualize it." ] }, { "cell_type": "code", "execution_count": null, "id": "09e6e6d6", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "9ee63bab", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "bfdadf23", "metadata": {}, "source": [ "Are you discarding fewer points in the persistence diagram now?" ] }, { "cell_type": "markdown", "id": "ec3ec80a", "metadata": {}, "source": [ "Another approach with more theoretical guarantees is to use the persistence landscapes associated to the persistence diagram. Indeed, valid confidence regions can be easily obtained using, e.g., Algorithm 1 in [this article](https://geometrica.saclay.inria.fr/team/Fred.Chazal/papers/cflrw-scpls-14/cflrw-scpls-14.pdf). In the following, we will fix a subsample size $s(n)$, and estimate $\\mathbb{E}[\\Lambda_{s(n)}]$, where $\\Lambda_{s(n)}$ is the landscape of a random subsample of size $s(n)$ (i.e., drawn from a probability measure $\\mu$ such as, e.g., the empirical measure). " ] }, { "cell_type": "markdown", "id": "aa490204", "metadata": {}, "source": [ "Let's first make sure that we can compute landscapes ;-) Use `Gudhi` to compute and plot the first six persistence landscapes associated to the persistence diagram computed above in homology dimension 1. 
Landscapes (and other vectorizations) are implemented with the API of `Scikit-Learn` estimators, which means that you have to call the `fit_transform` method on a list of persistence diagrams in order to get their landscapes. " ] }, { "cell_type": "code", "execution_count": null, "id": "7fdb8694", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "079caf08", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "4803257c", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "690074d2", "metadata": {}, "source": [ "Write a function `landscape_interval` that implements the landscape bootstrap procedure, that is, drawing many subsamples of size $s(n)$, computing their Alpha persistence diagrams and landscapes, computing the distribution of distances between each single landscape and their mean (multiplied by a random normal variable), and finally using the quantiles of this distribution in order to obtain confidence regions for the mean landscape." ] }, { "cell_type": "code", "execution_count": null, "id": "d8bade75", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "dc38520f", "metadata": {}, "source": [ "Apply and visualize the confidence interval around the different landscapes." ] }, { "cell_type": "code", "execution_count": null, "id": "4a6de968", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "927ed0d2", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "9ed1f4c1", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "3ae4c7d7", "metadata": {}, "source": [ "The confidence regions are much better now!" 
] }, { "cell_type": "markdown", "id": "2f29f592", "metadata": {}, "source": [ "Another interesting property of mean landscapes is their robustness to noise:\n", "\n", "$$\\|\\mathbb{E}[\\Lambda_{s(n)}^X]-\\mathbb{E}[\\Lambda_{s(n)}^Y]\\|_\\infty\\leq 2 \\cdot s(n) \\cdot d_{GW}(\\mu,\\nu),$$\n", "\n", "where $d_{GW}$ is the 1-Gromov-Wasserstein distance between probability measures. See Remark 6 in [this article](https://geometrica.saclay.inria.fr/team/Fred.Chazal/papers/cflmrw-smph-15/ICMLFinal.pdf). We will now confirm this by adding outlier noise to the 3D shape and looking at the resulting mean landscape. " ] }, { "cell_type": "markdown", "id": "a68c506e", "metadata": {}, "source": [ "Create a noisy version of `vertices` with some outlier noise." ] }, { "cell_type": "code", "execution_count": null, "id": "f028e220", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "50a4ceaa", "metadata": {}, "source": [ "Let's first compare the persistence landscapes of the two sets of vertices. Compute and visualize these landscapes on the same plot." ] }, { "cell_type": "code", "execution_count": null, "id": "276bddd7", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "f9c6fb84", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "df2a619e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "473c0f2e", "metadata": {}, "source": [ "As one can see, they are quite different. By contrast, computing the mean landscape with subsamples is much more robust, as we will now see." ] }, { "cell_type": "markdown", "id": "ab6bb739", "metadata": {}, "source": [ "Compute the mean persistence landscape of the noisy point cloud, and visualize it next to the mean persistence landscape of the clean point cloud." 
] }, { "cell_type": "code", "execution_count": null, "id": "02ae1659", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "abf6231e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "d98e06b2", "metadata": {}, "source": [ "Now, these landscapes look much more in agreement!" ] }, { "cell_type": "markdown", "id": "6dd46d4a", "metadata": {}, "source": [ "# Section 2: 3D shape classification with persistence diagrams" ] }, { "cell_type": "markdown", "id": "4613881a", "metadata": {}, "source": [ "In this section, our goal is to use persistence diagrams for classifying and segmenting 3D shapes with supervised machine learning. " ] }, { "cell_type": "markdown", "id": "784b38f1", "metadata": {}, "source": [ "Let's start with classification. We will compute persistence diagrams for all shapes in different categories, and train a classifier from `Scikit-Learn` to predict the category from the persistence diagrams. Since `Gudhi` requires simplex trees for the persistence diagram computations, write a `get_simplex_tree_from_faces` function that builds a simplex tree from the faces of a given 3D shape triangulation." ] }, { "cell_type": "code", "execution_count": null, "id": "5e11b0d8", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "0676683a", "metadata": {}, "source": [ "Now, compute all the persistence diagrams (in homology dimension 0) associated to the sublevel sets of the third coordinate from a few categories, and retrieve their corresponding labels." 
] }, { "cell_type": "code", "execution_count": null, "id": "da0283a9", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "a5f46e1e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "260189d9", "metadata": {}, "source": [ "As discussed in class, it is not very convenient to use persistence diagrams directly for machine learning purposes (except for a few methods such as $K$-nearest neighbors). What we need is to define a vectorization, that is, a map $\\Phi:\\mathcal{D}\\rightarrow\\mathcal{H}$ sending persistence diagrams into a Hilbert space, or equivalently, a symmetric kernel function $k:\\mathcal{D}\\times \\mathcal{D} \\rightarrow \\mathbb{R}$ such that $k(D,D')=\\langle \\Phi(D),\\Phi(D')\\rangle$. Fortunately, there are already a bunch of such maps and kernels in `Gudhi` :-)" ] }, { "cell_type": "markdown", "id": "60984d64", "metadata": {}, "source": [ "In the following we will compute and visualize the most popular kernels on some persistence diagrams. Pick first a specific persistence diagram and use `DiagramSelector` to remove its points with infinite coordinates." ] }, { "cell_type": "code", "execution_count": null, "id": "2963ea21", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "9761bda2", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "56872c58", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "0f3726d8", "metadata": {}, "source": [ "Now, let's see what `Gudhi` has to offer to vectorize persistence diagrams with `Scikit-Learn` estimator-like classes, that is, with classes that have `fit`, `transform`, and `fit_transform` methods, see [this article](https://arxiv.org/pdf/1309.0238.pdf) for more details. 
For each vectorization mentioned below, we recommend that you play with its parameters and observe their influence on the output in order to get some intuition. " ] }, { "cell_type": "markdown", "id": "fd08cc20", "metadata": {}, "source": [ "The first vectorization method that was introduced historically is the persistence landscape. A persistence landscape is basically obtained by rotating the persistence diagram by $-\\pi/4$\n", "(so that the diagonal becomes the $x$-axis), and then putting tent functions on each point. The $k$th landscape is then defined as the $k$th largest value among all these tent functions. It is eventually turned into a vector by evaluating it on a bunch of uniformly sampled points on the $x$-axis." ] }, { "cell_type": "markdown", "id": "21913ef9", "metadata": {}, "source": [ "Compute and visualize the first landscape of the persistence diagram for various parameters." ] }, { "cell_type": "code", "execution_count": null, "id": "427e5caa", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "efbe4807", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "f52bfe1a", "metadata": {}, "source": [ "A variation, called the silhouette, takes a weighted average of these tent functions instead. Here, we weight each tent function by the distance of the corresponding point to the diagonal." ] }, { "cell_type": "code", "execution_count": null, "id": "84d1dc1b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "a7273228", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "12dcbec7", "metadata": {}, "source": [ "The second method is the persistence image. A persistence image is obtained by rotating by $-\\pi/4$, centering Gaussian functions on all diagram points (usually weighted by a parameter function, such as, e.g., the squared distance to the diagonal) and summing all these Gaussians. 
This gives a 2D function, which is pixelized into an image." ] }, { "cell_type": "code", "execution_count": null, "id": "96fcc21d", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "4c45a15a", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "6a655394", "metadata": {}, "source": [ "`Gudhi` also has a variety of metrics and kernels, which sometimes perform better than explicit vectorizations such as the ones described above. Pick another persistence diagram, and get familiar with the bottleneck and the Wasserstein distances between the two diagrams. Note that you can call them in different ways in `Gudhi`: there are `bottleneck_distance` and `wasserstein_distance` functions for instance, but there are also wrappers of these functions into estimator classes `BottleneckDistance` and `WassersteinDistance` (with `fit` and `transform` methods). These classes are especially useful when doing model selection with `Scikit-Learn`, see below." ] }, { "cell_type": "code", "execution_count": null, "id": "735593c9", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "1bce755c", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "96a1bad0", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "35501ba5", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "4b1bfa66", "metadata": {}, "source": [ "`Gudhi` also lets you use standard kernels such as, among others, the persistence scale space kernel, persistence Fisher kernel, sliced Wasserstein kernel, etc. Try computing the kernel values for your pair of diagrams." 
] }, { "cell_type": "code", "execution_count": null, "id": "120e07bb", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "3a6c5946", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "0dfc927b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "19ad7499", "metadata": {}, "source": [ "Before trying to classify the persistence diagrams, let's do a quick dimension reduction. Apply `PCA`, `KernelPCA` or `MDS` (available in `Scikit-Learn`) on the explicit maps (landscapes, images, etc), kernel matrices (Fisher, sliced Wasserstein, etc) and distance matrices (bottleneck, Wasserstein, etc) respectively." ] }, { "cell_type": "code", "execution_count": null, "id": "0af6434a", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "22c9b448", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "24cbd9fd", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "434e5ace", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "9319273b", "metadata": { "scrolled": false }, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "210cd09a", "metadata": {}, "source": [ "Is there any method that looks better at separating the categories, at least by eye?" ] }, { "cell_type": "markdown", "id": "9f0cd6b2", "metadata": {}, "source": [ "All right, let's try classification now! Shuffle the data, and create a random 80/20 train/test split." 
] }, { "cell_type": "code", "execution_count": null, "id": "c3a5be44", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "c289041b", "metadata": {}, "source": [ "Here is the best thing about having estimator-like classes: they can be integrated seamlessly in a `Pipeline` of `Scikit-Learn` for model selection and cross-validation! A `Pipeline` is itself an estimator, and is initialized with a list of estimators. It will just sequentially apply the `fit_transform` methods of the estimators in the list." ] }, { "cell_type": "markdown", "id": "f376b990", "metadata": {}, "source": [ "Define a `Pipeline` with four estimators: one for selecting the finite persistence diagram points, one for scaling (or not) the persistence diagrams (with `DiagramScaler`), one for vectorizing persistence diagrams, and one for performing the final prediction. See the [documentation](https://scikit-learn.org/stable/modules/compose.html#combining-estimators)." ] }, { "cell_type": "code", "execution_count": null, "id": "83880198", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "09de4257", "metadata": {}, "source": [ "Now, define a grid of parameters that will be used in cross-validation." ] }, { "cell_type": "code", "execution_count": null, "id": "c811a19b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "ed78f354", "metadata": {}, "source": [ "Define and train the model." ] }, { "cell_type": "code", "execution_count": null, "id": "b718b3f0", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "c4a88c9a", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "0e5eaee3", "metadata": {}, "source": [ "Check the parameters that were chosen during model selection, and evaluate your model on the test set." 
] }, { "cell_type": "code", "execution_count": null, "id": "bdfd2f2e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "46647aba", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "dc84af79", "metadata": {}, "source": [ "How good is your score? How would you improve it? You can also try to use PersLay to learn which representation to use, see this [tutorial](https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-perslay-visu.ipynb)." ] }, { "cell_type": "code", "execution_count": null, "id": "03da67b7", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }