ott.tools.sinkhorn_divergence.sinkhorn_divergence
- ott.tools.sinkhorn_divergence.sinkhorn_divergence(geom, *args, a=None, b=None, sinkhorn_kwargs=None, static_b=False, share_epsilon=True, **kwargs)
Compute the Sinkhorn divergence defined by a geometry, weights, and parameters.
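For reference, the Sinkhorn divergence is usually defined as the entropic OT cost between the two measures, debiased by the two self-comparison terms. The form below is the standard textbook one and is assumed here rather than taken from the OTT source. In each term, C is the cost matrix of the corresponding geometry (x vs y, x vs x, y vs y), U(a, b) is the set of couplings with marginals a and b, and the KL-to-product regularizer is one common choice:

```latex
\begin{aligned}
\mathrm{OT}_\varepsilon(a, b) &= \min_{P \in U(a, b)} \langle P, C \rangle + \varepsilon\, \mathrm{KL}\big(P \,\|\, a b^\top\big), \\
S_\varepsilon(a, b) &= \mathrm{OT}_\varepsilon(a, b) - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(a, a) - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(b, b).
\end{aligned}
```

With share_epsilon=True, the same ε appears in all three terms. The last term depends only on b, so skipping it (static_b=True) shifts the value by a constant whenever b is held fixed, which is presumably why it can be left out.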
- Parameters
  - geom (Geometry) – a geometry class.
  - args (Any) – arguments to ott.geometry.geometry.Geometry.prepare_divergences() that are specific to each geometry.
  - a (Optional[ndarray]) – jnp.ndarray<float>[n]: the weight of each input point. The sum of all elements of b must match that of a to converge.
  - b (Optional[ndarray]) – jnp.ndarray<float>[m]: the weight of each target point. The sum of all elements of b must match that of a to converge.
  - sinkhorn_kwargs (Optional[Dict[str, Any]]) – optionally, a dict containing the keyword arguments for calls to the sinkhorn function, which is called twice if static_b is True and three times otherwise.
  - static_b (bool) – if True, the divergence of measure b against itself is NOT computed.
  - share_epsilon (bool) – if True, enforces that the same epsilon regularizer is shared by all 2 or 3 terms of the Sinkhorn divergence. In that case, epsilon defaults to the value used when comparing x to y (contained in the first geometry). This flag is True by default because, in the default setting, the epsilon regularization is a function of the mean of the cost matrix.
  - kwargs (Any) – keyword arguments to the generic class. These are specific to each geometry.
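To make the roles of static_b and share_epsilon concrete, here is a minimal, self-contained sketch in plain Python (no OTT or JAX). It is not OTT's implementation. It assumes squared-Euclidean point clouds, positive weights that sum to 1, a log-domain Sinkhorn loop with a fixed iteration count, and a hypothetical default epsilon of 5% of the mean cost; every function name here is illustrative only.

```python
import math


def sq_euclidean_cost(x, y):
    """Cost matrix C[i][j] = ||x_i - y_j||^2 for lists of coordinate tuples."""
    return [[sum((pi - qi) ** 2 for pi, qi in zip(p, q)) for q in y] for p in x]


def default_epsilon(cost, scale=0.05):
    """Hypothetical default rule: a fraction of the mean cost (not OTT's exact rule)."""
    n_entries = len(cost) * len(cost[0])
    return scale * sum(sum(row) for row in cost) / n_entries


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def reg_ot_cost(cost, a, b, epsilon, n_iter=500):
    """Entropic OT cost via log-domain Sinkhorn (weights positive, summing to 1).

    Returns the dual objective <a, f> + <b, g> at the final potentials.
    """
    n, m = len(a), len(b)
    log_a = [math.log(w) for w in a]
    log_b = [math.log(w) for w in b]
    f, g = [0.0] * n, [0.0] * m
    for _ in range(n_iter):
        # Update f so that the row marginals of the plan match a ...
        f = [-epsilon * logsumexp([log_b[j] + (g[j] - cost[i][j]) / epsilon
                                   for j in range(m)]) for i in range(n)]
        # ... then g so that the column marginals match b.
        g = [-epsilon * logsumexp([log_a[i] + (f[i] - cost[i][j]) / epsilon
                                   for i in range(n)]) for j in range(m)]
    return sum(w * v for w, v in zip(a, f)) + sum(w * v for w, v in zip(b, g))


def sinkhorn_divergence_sketch(x, y, a=None, b=None, epsilon=None,
                               static_b=False, share_epsilon=True):
    """Debiased Sinkhorn divergence between point clouds x and y (illustrative)."""
    a = [1.0 / len(x)] * len(x) if a is None else a
    b = [1.0 / len(y)] * len(y) if b is None else b
    # Terms: (x vs y), (x vs x), and (y vs y) unless b is static.
    terms = [(x, y, a, b), (x, x, a, a)]
    if not static_b:
        terms.append((y, y, b, b))
    costs = [sq_euclidean_cost(p, q) for p, q, _, _ in terms]
    # Epsilon of the first (x vs y) geometry; shared by all terms if requested.
    eps_xy = epsilon if epsilon is not None else default_epsilon(costs[0])
    ot_values = []
    for c, (_, _, w1, w2) in zip(costs, terms):
        if epsilon is not None or share_epsilon:
            eps = eps_xy
        else:
            eps = default_epsilon(c)  # each term picks its own default
        ot_values.append(reg_ot_cost(c, w1, w2, eps))
    divergence = ot_values[0] - 0.5 * sum(ot_values[1:])
    return divergence, tuple(ot_values)


if __name__ == "__main__":
    x = [(0.0, 0.0), (1.0, 0.0)]
    y = [(3.0, 0.0), (4.0, 0.0)]
    print(sinkhorn_divergence_sketch(x, y))
    print(sinkhorn_divergence_sketch(x, y, static_b=True))
```

Comparing a cloud with itself gives zero, because the three terms are then computed identically. Setting static_b=True runs only two solves and drops the b-against-b term, mirroring the "called twice" behavior of sinkhorn_kwargs. Setting share_epsilon=False lets each self-comparison pick its own default epsilon from its own cost matrix.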
- Returns
(Sinkhorn divergence value, three pairs of potentials, three costs)
- Return type