# Point clouds

In this tutorial, we cover how to instantiate and use a `PointCloud` geometry.

A `PointCloud` geometry holds two arrays of vectors (the points of each cloud), together with a cost function used to compare them pairwise. This geometry should cover most users' needs.
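To make "two arrays of vectors plus a cost function" concrete, here is a minimal plain-JAX sketch. `ToyPointCloud` and `sq_euclidean` are hypothetical names introduced for illustration, not OTT classes; OTT's `PointCloud` is more elaborate, and we only assume that its cost matrix has entries `C[i, j] = cost_fn(x[i], y[j])`.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyPointCloud(NamedTuple):
    """Hypothetical stand-in for a point-cloud geometry (not OTT's class):
    two arrays of vectors plus a pairwise cost function."""

    x: jnp.ndarray  # shape (n, d)
    y: jnp.ndarray  # shape (m, d)
    cost_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]

    @property
    def cost_matrix(self) -> jnp.ndarray:
        # Nested vmaps apply cost_fn to every pair (x_i, y_j) -> shape (n, m).
        over_y = jax.vmap(self.cost_fn, in_axes=(None, 0))
        over_xy = jax.vmap(over_y, in_axes=(0, None))
        return over_xy(self.x, self.y)


def sq_euclidean(u: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
    # Squared Euclidean distance between two vectors.
    return jnp.sum((u - v) ** 2)


toy_geom = ToyPointCloud(
    x=jnp.array([[0.0, 0.0], [1.0, 0.0]]),
    y=jnp.array([[0.0, 0.0], [0.0, 2.0], [3.0, 0.0]]),
    cost_fn=sq_euclidean,
)
C_toy = toy_geom.cost_matrix  # [[0., 4., 9.], [1., 5., 4.]]
```

Swapping `cost_fn` changes the geometry without touching the point arrays, which is the same separation the tutorial relies on later when it switches cost functions.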

We also show how to differentiate through optimal transport, as an example of an optimization that leverages first-order gradients.

```python
import sys

if "google.colab" in sys.modules:
    !pip install -q git+https://github.com/ott-jax/ott@main
```

```python
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

import ott
from ott.geometry import pointcloud
from ott.core import sinkhorn
from ott.tools import plot  # makes ott.tools.plot available to the plotting cells
from ott.tools import transport
```


## Creates a PointCloud geometry

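```python
def create_points(rng, n, m, d):
    rngs = jax.random.split(rng, 3)
    # n source points drawn from a Gaussian shifted by +1.
    x = jax.random.normal(rngs[0], (n, d)) + 1
    # m target points drawn uniformly from [0, 1]^d.
    y = jax.random.uniform(rngs[1], (m, d))
    # Uniform weights on both point clouds.
    a = jnp.ones((n,)) / n
    b = jnp.ones((m,)) / m
    return x, y, a, b


rng = jax.random.PRNGKey(0)
n, m, d = 12, 14, 2
x, y, a, b = create_points(rng, n=n, m=m, d=d)
```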


## Computes the regularized optimal transport

To compute the transport matrix between the two point clouds, define a `PointCloud` geometry (which uses `ott.geometry.costs.Euclidean` as its cost function by default), call the `sinkhorn` function, and build the transport matrix from the optimized potentials.

```python
geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
out = sinkhorn.sinkhorn(geom, a, b)
# Transport matrix built from the dual potentials f and g.
P = geom.transport_from_potentials(out.f, out.g)
```
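For intuition about what `sinkhorn.sinkhorn` and `transport_from_potentials` compute, here is a minimal log-domain Sinkhorn sketch in plain JAX. It is not OTT's implementation: it runs a fixed number of iterations with no convergence check, and it assumes the transport matrix is recovered from the potentials as `P[i, j] = exp((f[i] + g[j] - C[i, j]) / epsilon)`. The function names and the `*_toy` data are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sq_euclidean_cost(x, y):
    # C[i, j] = ||x_i - y_j||^2, shape (n, m).
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def sinkhorn_log(C, a, b, epsilon, num_iter=1000):
    """Alternating log-domain updates of the potentials f (size n) and g (size m)."""

    def body(_, fg):
        f, g = fg
        # Update f so that the rows of P sum to a.
        f = epsilon * jnp.log(a) - epsilon * logsumexp((g[None, :] - C) / epsilon, axis=1)
        # Update g so that the columns of P sum to b.
        g = epsilon * jnp.log(b) - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    return jax.lax.fori_loop(0, num_iter, body, init)


def transport_from_potentials(f, g, C, epsilon):
    # Entropic transport plan recovered from the potentials.
    return jnp.exp((f[:, None] + g[None, :] - C) / epsilon)


x_toy = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
y_toy = jnp.array([[0.0, 0.0], [1.0, 1.0]])
a_toy = jnp.ones(3) / 3
b_toy = jnp.ones(2) / 2
C_toy = sq_euclidean_cost(x_toy, y_toy)
f_toy, g_toy = sinkhorn_log(C_toy, a_toy, b_toy, epsilon=0.5)
P_toy = transport_from_potentials(f_toy, g_toy, C_toy, epsilon=0.5)
```

Each `f` update rescales the rows of the plan to sum to `a`, and each `g` update rescales its columns to sum to `b`; alternating the two drives both marginal constraints to hold at convergence. Working with `logsumexp` instead of the kernel `exp(-C / epsilon)` avoids underflow when `epsilon` is small, such as the `1e-2` used above.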


A more concise way to compute the optimal transport matrix is `transport.solve`. Note that the weights are assumed to be uniform if `a` and `b` are not passed to `transport.solve`.

```python
ot = transport.solve(x, y, a=a, b=b, epsilon=1e-2)
```
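The uniform default for the weights can be illustrated with a small hypothetical helper, `default_weights`, which is not part of OTT: each of the n points in `x` receives mass 1/n, and each of the m points in `y` receives mass 1/m.

```python
from typing import Optional, Tuple

import jax.numpy as jnp


def default_weights(
    n: int,
    m: int,
    a: Optional[jnp.ndarray] = None,
    b: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Hypothetical helper: fall back to uniform weights when a or b is missing."""
    if a is None:
        a = jnp.ones((n,)) / n  # each of the n points gets mass 1/n
    if b is None:
        b = jnp.ones((m,)) / m  # each of the m points gets mass 1/m
    return a, b


a_default, b_default = default_weights(12, 14)
```

These are the same weights that `create_points` builds, so in this example calling `transport.solve(x, y, epsilon=1e-2)` without `a` and `b` should give the same result.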


## Visualizes the transport

```python
plt.imshow(ot.matrix, cmap="Purples")
plt.colorbar();
```

```python
plott = ott.tools.plot.Plot()
_ = plott(ot)
```

## Differentiation through Optimal Transport

The quantities returned by OTT are differentiable. In the following example, we use their gradients to move the n source points in `x` so as to minimize the overall regularized OT cost, given a ground cost function (here, the squared Euclidean distance).
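The mechanism can be sketched without OTT: write the Sinkhorn iterations as ordinary JAX code, and `jax.value_and_grad` differentiates through the unrolled loop. This is a minimal sketch under stated assumptions. It uses a fixed number of iterations, `epsilon=0.5`, and the dual objective `f @ a + g @ b` as a stand-in for OTT's `reg_ot_cost`, which may differ in constants and in how its gradients are computed. `reg_ot_objective`, `descend`, and the `*_toy` data are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def reg_ot_objective(x, y, a, b, epsilon=0.5, num_iter=200):
    """Dual objective f @ a + g @ b after a fixed number of log-domain
    Sinkhorn iterations (a rough stand-in for OTT's reg_ot_cost)."""
    # Squared Euclidean ground cost, shape (n, m).
    C = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)

    def body(_, fg):
        f, g = fg
        f = epsilon * jnp.log(a) - epsilon * logsumexp((g[None, :] - C) / epsilon, axis=1)
        g = epsilon * jnp.log(b) - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, num_iter, body, init)
    return jnp.dot(f, a) + jnp.dot(g, b)


# Gradient w.r.t. x only (argnums=0), obtained by differentiating through
# the unrolled Sinkhorn loop.
objective_and_grad = jax.jit(jax.value_and_grad(reg_ot_objective, argnums=0))


def descend(x, y, a, b, num_steps=50, learning_rate=0.5):
    """Plain gradient descent on the source points x; returns the final
    points and the objective value recorded before each step."""
    history = []
    for _ in range(num_steps):
        value, grad_x = objective_and_grad(x, y, a, b)
        history.append(float(value))
        x = x - learning_rate * grad_x
    return x, history


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x_toy = jax.random.normal(key_x, (12, 2)) + 1
y_toy = jax.random.uniform(key_y, (14, 2))
a_toy = jnp.ones(12) / 12
b_toy = jnp.ones(14) / 14
x_moved, history = descend(x_toy, y_toy, a_toy, b_toy)
```

Differentiating through unrolled iterations is only one strategy; OTT may instead rely on implicit differentiation of the fixed point, so its gradients can differ slightly from this sketch.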

```python
def optimize(
    x: jnp.ndarray,
    y: jnp.ndarray,
    a: jnp.ndarray,
    b: jnp.ndarray,
    cost_fn=ott.geometry.costs.Euclidean(),
    num_iter: int = 101,
    dump_every: int = 10,
    learning_rate: float = 0.2,
):
    # Regularized OT cost and its gradient w.r.t. the geometry (argument 0).
    reg_ot_cost_vg = jax.jit(
        jax.value_and_grad(
            lambda geom, a, b: ott.core.sinkhorn.sinkhorn(geom, a, b).reg_ot_cost,
            argnums=0,
        )
    )

    ot = transport.solve(
        x, y, a=a, b=b, cost_fn=cost_fn, epsilon=1e-2, jit=True
    )
    result = [ot]
    for i in range(1, num_iter + 1):
        # geom_g has the structure of the geometry; geom_g.x is d(cost)/dx.
        reg_ot_cost, geom_g = reg_ot_cost_vg(ot.geom, ot.a, ot.b)
        # Gradient step on the source points only.
        x = x - geom_g.x * learning_rate
        ot = transport.solve(
            x, y, a=a, b=b, cost_fn=cost_fn, epsilon=1e-2, jit=True
        )
        # Keep a snapshot every `dump_every` iterations for the animation.
        if i % dump_every == 0:
            result.append(ot)

    return result
```
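In `optimize`, the gradient is taken with respect to the geometry object itself, and the gradient for the points is read from `geom_g.x`. This works because JAX differentiates with respect to pytrees and returns a gradient with the same structure as its input. The sketch below shows the same pattern on a hypothetical `Points` NamedTuple standing in for the geometry; we assume `PointCloud` is registered as a pytree, as the use of `geom_g.x` suggests.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Points(NamedTuple):
    """Hypothetical stand-in for a geometry; NamedTuples are JAX pytrees."""

    x: jnp.ndarray
    y: jnp.ndarray


def mean_sq_dist(p: Points) -> jnp.ndarray:
    # Mean squared Euclidean distance over all (x_i, y_j) pairs.
    diff = p.x[:, None, :] - p.y[None, :, :]
    return jnp.mean(jnp.sum(diff ** 2, axis=-1))


p = Points(x=jnp.zeros((3, 2)), y=jnp.ones((4, 2)))
value, grads = jax.value_and_grad(mean_sq_dist)(p)
# grads is itself a Points instance: grads.x has p.x's shape, grads.y has p.y's.
```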

```python
from IPython import display

ots = optimize(x, y, a, b, num_iter=100, cost_fn=ott.geometry.costs.Euclidean())
fig = plt.figure(figsize=(8, 5))
plott = ott.tools.plot.Plot(fig=fig)
anim = plott.animate(ots, frame_rate=4)
html = display.HTML(anim.to_jshtml())
display.display(html)
plt.close()
```


We can also use another cost function, here the cosine distance, which leads to different optimization dynamics.
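As a reference point, here is one common definition of the cosine distance, 1 - cos(angle(u, v)), sketched in JAX. The name `cosine_distance` is hypothetical, and OTT's `costs.Cosine` may differ in details such as how it handles zero-norm vectors.

```python
import jax
import jax.numpy as jnp


def cosine_distance(u: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
    """1 - cosine similarity, with values in [0, 2]; assumes u, v are non-zero."""
    return 1.0 - jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))


u = jnp.array([1.0, 0.0])
d_orthogonal = cosine_distance(u, jnp.array([0.0, 1.0]))  # 1.0
d_same_direction = cosine_distance(u, jnp.array([2.0, 0.0]))  # 0.0, the norm is ignored
d_opposite = cosine_distance(u, jnp.array([-1.0, 0.0]))  # 2.0

# Scale invariance in u implies that the gradient w.r.t. u is orthogonal to u.
u2 = jnp.array([1.0, 2.0])
grad_u2 = jax.grad(cosine_distance)(u2, jnp.array([3.0, -1.0]))
```

The last lines illustrate one property that plausibly contributes to the different dynamics: because the cosine distance is unchanged when a point is rescaled, its gradient with respect to a point is orthogonal to that point, so, to first order, gradient steps change the direction of each point rather than its norm.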

```python
ots = optimize(x, y, a, b, num_iter=100, cost_fn=ott.geometry.costs.Cosine())
fig = plt.figure(figsize=(8, 5))
plott = ott.tools.plot.Plot(fig=fig)
anim = plott.animate(ots, frame_rate=8)
html = display.HTML(anim.to_jshtml())
display.display(html)
plt.close()
```