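# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.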
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import jax.scipy as jsp

from ott import utils
from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_costs
from ott.types import Transport

if TYPE_CHECKING:
  from ott.solvers.linear import sinkhorn_lr

@jax.tree_util.register_pytree_node_class
class QuadraticProblem:
  r"""Quadratic OT problem.

  The quadratic loss of a single OT matrix is assumed to
  have the form given in :cite:`peyre:16`, eq. 4.

  The two geometries below parameterize matrices :math:`C` and
  :math:`\bar{C}` in that equation. The function :math:`L` (of two real
  values) in that equation is assumed to match the form given in eq. 5.,
  with our notations:

  .. math::
    L(x, y) = f_1(x) + f_2(y) - h_1(x) h_2(y)

  Args:
    geom_xx: Ground geometry of the first space.
    geom_yy: Ground geometry of the second space.
    geom_xy: Geometry defining the linear penalty term for
      fused Gromov-Wasserstein :cite:`vayer:19`. If :obj:`None`, the problem
      reduces to a plain Gromov-Wasserstein problem :cite:`peyre:16`.
    fused_penalty: Multiplier of the linear term in fused Gromov-Wasserstein,
      i.e. ``problem = purely quadratic + fused_penalty * linear problem``.
    scale_cost: How to rescale the cost matrices. If a :class:`str`,
      use specific options available in
      :class:`~ott.geometry.geometry.Geometry` or
      :class:`~ott.geometry.pointcloud.PointCloud`. If :obj:`None`, keep
      the original scaling.
    a: The first marginal. If :obj:`None`, it will be uniform.
    b: The second marginal. If :obj:`None`, it will be uniform.
    loss: Gromov-Wasserstein loss function, see
      :class:`~ott.problems.quadratic.quadratic_costs.GWLoss` for more
      information.
    tau_a: If :math:`< 1.0`, defines how much unbalanced the problem is on
      the first marginal.
    tau_b: If :math:`< 1.0`, defines how much unbalanced the problem is on
      the second marginal.
    gw_unbalanced_correction: Whether the unbalanced version of
      :cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b``
      only affect the inner Sinkhorn loop.
    ranks: Ranks of the cost matrices, see
      :meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry`. Used when
      geometries are *not* :class:`~ott.geometry.pointcloud.PointCloud` with
      ``'sqeucl'`` cost function. If ``-1``, the geometries will not be
      converted to low-rank. If :class:`tuple`, it specifies the ranks of
      ``geom_xx``, ``geom_yy`` and ``geom_xy``, respectively. If
      :class:`int`, rank is shared across all geometries.
    tolerances: Tolerances used when converting geometries to low-rank. Used
      when geometries are not :class:`~ott.geometry.pointcloud.PointCloud`
      with ``'sqeucl'`` cost. If :class:`float`, it is shared across all
      geometries.
  """

  def __init__(
      self,
      geom_xx: geometry.Geometry,
      geom_yy: geometry.Geometry,
      geom_xy: Optional[geometry.Geometry] = None,
      fused_penalty: float = 1.0,
      scale_cost: Optional[Union[float, str]] = None,
      a: Optional[jnp.ndarray] = None,
      b: Optional[jnp.ndarray] = None,
      loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl",
      tau_a: float = 1.0,
      tau_b: float = 1.0,
      gw_unbalanced_correction: bool = True,
      ranks: Union[int, Tuple[int, ...]] = -1,
      tolerances: Union[float, Tuple[float, ...]] = 1e-2,
  ):
    # Apply the requested cost scaling to every geometry that was supplied.
    if scale_cost is not None:
      geom_xx = geom_xx.set_scale_cost(scale_cost)
      geom_yy = geom_yy.set_scale_cost(scale_cost)
      if geom_xy is not None:
        geom_xy = geom_xy.set_scale_cost(scale_cost)

    self._geom_xx = geom_xx
    self._geom_yy = geom_yy
    self._geom_xy = geom_xy
    self.fused_penalty = fused_penalty
    self.scale_cost = scale_cost
    self._a = a
    self._b = b
    self.tau_a = tau_a
    self.tau_b = tau_b
    self.gw_unbalanced_correction = gw_unbalanced_correction
    self.ranks = ranks
    self.tolerances = tolerances

    # Keep the raw ``loss`` argument: it is what ``tree_flatten`` stores and
    # what ``marginal_dependent_cost`` compares against "sqeucl".
    self._loss_name = loss
    # NOTE(review): the factory names below are assumed to be provided by
    # ``quadratic_costs`` -- confirm against that module.
    if self._loss_name == "sqeucl":
      self.loss = quadratic_costs.make_square_loss()
    elif loss == "kl":
      self.loss = quadratic_costs.make_kl_loss()
    else:
      # Anything else is taken to be a ready-made loss object.
      self.loss = loss

  def marginal_dependent_cost(
      self,
      marginal_1: jnp.ndarray,
      marginal_2: jnp.ndarray,
  ) -> low_rank.LRCGeometry:
    r"""Initialize cost term that depends on the marginals of the transport.

    Uses the first term in eq. 6, p. 1 of :cite:`peyre:16`.

    Let :math:`p` be the ``[n,]`` marginal of the transport matrix for samples
    from :attr:`geom_xx` and :math:`q` the ``[m,]`` marginal of the
    transport matrix for samples from :attr:`geom_yy`.

    When ``cost_xx`` (resp. ``cost_yy``) is the cost matrix of :attr:`geom_xx`
    (resp. :attr:`geom_yy`), the cost term that depends on these marginals can
    be written as:

    .. math::

      \text{marginal_dep_term} = \text{lin1}(\text{cost_xx}) p \mathbb{1}_{m}^T
      +  \mathbb{1}_{n}(\text{lin2}(\text{cost_yy}) q)^T

    This helper function instantiates these two low-rank matrices and groups
    them into a single low-rank cost geometry object.

    Args:
      marginal_1: ``[n,]``, first marginal of transport matrix.
      marginal_2: ``[m,]``, second marginal of transport matrix.

    Returns:
      Low-rank geometry of rank 2, storing normalization constants.
    """
    geom_xx, geom_yy = self.geom_xx, self.geom_yy
    if self._loss_name == "sqeucl":  # quadratic apply, efficient for LR
      tmp1 = geom_xx.apply_square_cost(marginal_1, axis=1)
      tmp2 = geom_yy.apply_square_cost(marginal_2, axis=1)
    else:
      f1, f2 = self.linear_loss
      tmp1 = apply_cost(geom_xx, marginal_1, axis=1, fn=f1)
      tmp2 = apply_cost(geom_yy, marginal_2, axis=1, fn=f2)
    # Pair each term with a column of ones so that the product of the two
    # factors yields ``tmp1 1^T + 1 tmp2^T``.
    x_term = jnp.concatenate((tmp1, jnp.ones_like(tmp1)), axis=1)
    y_term = jnp.concatenate((jnp.ones_like(tmp2), tmp2), axis=1)
    return low_rank.LRCGeometry(cost_1=x_term, cost_2=y_term)

  def cost_unbalanced_correction(
      self,
      transport_matrix: jnp.ndarray,
      marginal_1: jnp.ndarray,
      marginal_2: jnp.ndarray,
      epsilon: epsilon_scheduler.Epsilon,
  ) -> float:
    r"""Calculate cost term from the quadratic divergence when unbalanced.

    In the unbalanced setting (``tau_a < 1.0 or tau_b < 1.0``), the
    introduction of a quadratic divergence :cite:`sejourne:21` adds a term
    to the GW local cost.

    Let :math:`a` ``[num_a,]`` be the target weights for samples
    from ``geom_xx`` and :math:`b` ``[num_b,]`` be the target weights
    for samples from ``geom_yy``. Let :math:`P` ``[num_a, num_b]`` be the
    transport matrix, :math:`P1` the first marginal and :math:`P^T1` the
    second marginal. The term of the cost matrix coming from the quadratic KL
    in the unbalanced case can be written as:

    unbalanced_correction_term =
    :math:`tau_a / (1 - tau_a) * \sum(KL(P1|a))`
    :math:`+ tau_b / (1 - tau_b) * \sum(KL(P^T1|b))`
    :math:`+ epsilon * \sum(KL(P|ab'))`

    Args:
      transport_matrix: jnp.ndarray<float>[num_a, num_b], transport matrix.
      marginal_1: jnp.ndarray<float>[num_a,], marginal of the transport matrix
        for samples from :attr:`geom_xx`.
      marginal_2: jnp.ndarray<float>[num_b,], marginal of the transport matrix
        for samples from :attr:`geom_yy`.
      epsilon: entropy regularizer.

    Returns:
      The cost term.
    """

    def regularizer(tau: float) -> float:
      # ``eps`` is bound below, before the first call.
      return eps * tau / (1.0 - tau)

    eps = epsilon._target_init
    marginal_1loga = jsp.special.xlogy(marginal_1, self.a).sum()
    marginal_2logb = jsp.special.xlogy(marginal_2, self.b).sum()

    cost = eps * jsp.special.xlogy(transport_matrix, transport_matrix).sum()
    # ``-entr(x) = x * log(x)``; each marginal term is only added when the
    # corresponding side is relaxed (``tau != 1``), which also avoids the
    # division by zero in ``regularizer``.
    if self.tau_a != 1.0:
      cost += regularizer(
          self.tau_a
      ) * (-jsp.special.entr(marginal_1).sum() - marginal_1loga)
    if self.tau_b != 1.0:
      cost += regularizer(
          self.tau_b
      ) * (-jsp.special.entr(marginal_2).sum() - marginal_2logb)
    return cost

  # TODO(michalk8): highly coupled to the pre-defined initializer, refactor
  def init_transport_mass(self) -> float:
    """Initialize the transport mass.

    Returns:
      The sum of the elements of the normalized transport matrix.
    """
    # The initial mass is treated as a constant: no gradient flows through it.
    a = jax.lax.stop_gradient(self.a)
    b = jax.lax.stop_gradient(self.b)
    return a.sum() * b.sum()

  def update_lr_geom(
      self,
      lr_sink: "sinkhorn_lr.LRSinkhornOutput",
      relative_epsilon: Optional[bool] = None,
  ) -> geometry.Geometry:
    """Recompute (possibly LRC) linearization using LR Sinkhorn output.

    Args:
      lr_sink: Output of the low-rank Sinkhorn solver; its factors ``q``,
        ``r``, ``g`` and its marginals are read.
      relative_epsilon: Forwarded to the constructed geometry.

    Returns:
      A low-rank geometry when :attr:`is_low_rank` holds, otherwise a dense
      geometry built from an explicit cost matrix.
    """
    marginal_1 = lr_sink.marginal(1)
    marginal_2 = lr_sink.marginal(0)
    marginal_cost = self.marginal_dependent_cost(marginal_1, marginal_2)

    # Extract factors from LR Sinkhorn output
    q, r, inv_sqg = lr_sink.q, lr_sink.r, 1.0 / jnp.sqrt(lr_sink.g)
    # Distribute middle marginal evenly across both factors.
    q, r = q * inv_sqg[None, :], r * inv_sqg[None, :]

    # Handle LRC Geometry case.
    h1, h2 = self.quad_loss
    geom_xx, geom_yy, geom_xy = self.geom_xx, self.geom_yy, self.geom_xy
    tmp1 = apply_cost(geom_xx, q, axis=1, fn=h1)
    tmp2 = apply_cost(geom_yy, r, axis=1, fn=h2)
    if self.is_low_rank:
      geom = low_rank.LRCGeometry(
          cost_1=tmp1, cost_2=-tmp2, relative_epsilon=relative_epsilon
      ) + marginal_cost
      if self.is_fused:
        geom = geom + geom_xy
    else:
      cost_matrix = marginal_cost.cost_matrix - jnp.dot(tmp1, tmp2.T)
      cost_matrix += self.fused_penalty * self._fused_cost_matrix
      geom = geometry.Geometry(
          cost_matrix=cost_matrix, relative_epsilon=relative_epsilon
      )
    return geom  # noqa: RET504

  def update_linearization(
      self,
      transport: Transport,
      epsilon: Optional[Union[epsilon_scheduler.Epsilon, float]] = None,
      old_transport_mass: float = 1.0,
      relative_epsilon: Optional[bool] = None,
  ) -> linear_problem.LinearProblem:
    """Update linearization of GW problem by updating cost matrix.

    If the problem is balanced (``tau_a = 1.0 and tau_b = 1.0``), the equation
    follows eq. 6, p. 1 of :cite:`peyre:16`.

    If the problem is unbalanced (``tau_a < 1.0 or tau_b < 1.0``), two cases
    are possible, as explained in :meth:`init_linearization`.
    Finally, it is also possible to consider a Fused Gromov-Wasserstein
    problem. Details about the resulting cost matrix are also given in
    :meth:`init_linearization`.

    Args:
      transport: Solution of the linearization of the quadratic problem.
      epsilon: An epsilon scheduler or a float passed on to the linearization.
      old_transport_mass: Sum of the elements of the transport matrix at the
        previous iteration.
      relative_epsilon: Whether to use relative epsilon in the linearized
        geometry.

    Returns:
      Updated linear OT problem, a new local linearization of GW problem.
    """
    rescale_factor = 1.0
    unbalanced_correction = 0.0

    if not self.is_balanced:
      marginal_1 = transport.marginal(axis=1)
      # Mass of the current plan, treated as a constant for differentiation.
      transport_mass = jax.lax.stop_gradient(marginal_1.sum())
      rescale_factor = jnp.sqrt(old_transport_mass / transport_mass)

    marginal_1 = transport.marginal(axis=1) * rescale_factor
    marginal_2 = transport.marginal(axis=0) * rescale_factor
    marginal_cost = self.marginal_dependent_cost(marginal_1, marginal_2)

    transport_matrix = transport.matrix * rescale_factor

    if not self.is_balanced:
      # Rescales transport for Unbalanced GW according to Sejourne et al. (2021)
      # Epsilon is scaled by the mass of the *rescaled* plan.
      transport_mass = jax.lax.stop_gradient(marginal_1.sum())
      epsilon = update_epsilon_unbalanced(epsilon, transport_mass)
      unbalanced_correction = self.cost_unbalanced_correction(
          transport_matrix, marginal_1, marginal_2, epsilon
      )

    h1, h2 = self.quad_loss
    geom_xx, geom_yy = self.geom_xx, self.geom_yy

    # Apply h1(cost_xx) on the left, then h2(cost_yy) on the right.
    tmp = apply_cost(geom_xx, transport_matrix, axis=1, fn=h1)
    tmp = apply_cost(geom_yy, tmp.T, axis=1, fn=h2).T

    cost_matrix = marginal_cost.cost_matrix - tmp + unbalanced_correction
    cost_matrix += self.fused_penalty * rescale_factor * self._fused_cost_matrix

    geom = geometry.Geometry(
        cost_matrix=cost_matrix,
        epsilon=epsilon,
        relative_epsilon=relative_epsilon
    )

    return linear_problem.LinearProblem(
        geom, self.a, self.b, tau_a=self.tau_a, tau_b=self.tau_b
    )

  def update_lr_linearization(
      self,
      lr_sink: "sinkhorn_lr.LRSinkhornOutput",
      *,
      relative_epsilon: Optional[bool] = None,
  ) -> linear_problem.LinearProblem:
    """Update a Quad problem linearization using a LR Sinkhorn.

    Args:
      lr_sink: Output of the low-rank Sinkhorn solver.
      relative_epsilon: Forwarded to :meth:`update_lr_geom`.

    Returns:
      Linear problem over the updated geometry, with the marginals and
      ``tau_a``/``tau_b`` of this problem.
    """
    return linear_problem.LinearProblem(
        self.update_lr_geom(lr_sink, relative_epsilon=relative_epsilon),
        self.a,
        self.b,
        tau_a=self.tau_a,
        tau_b=self.tau_b
    )

  @property
  def _fused_cost_matrix(self) -> Union[float, jnp.ndarray]:
    """Cost matrix of the linear (fused) term, or ``0.0`` if not fused."""
    if not self.is_fused:
      return 0.0
    geom_xy = self.geom_xy

    # Online point clouds get their matrix computed explicitly and rescaled.
    if isinstance(geom_xy, pointcloud.PointCloud) and geom_xy.is_online:
      return geom_xy._compute_cost_matrix() * geom_xy.inv_scale_cost
    return geom_xy.cost_matrix

  @property
  def _is_low_rank_convertible(self) -> bool:
    """Whether the geometries are, or can be converted to, low-rank."""

    def convertible(geom: geometry.Geometry) -> bool:
      return isinstance(geom, low_rank.LRCGeometry) or (
          isinstance(geom, pointcloud.PointCloud) and geom.is_squared_euclidean
      )

    if self.is_low_rank:
      return True

    geom_xx, geom_yy, geom_xy = self.geom_xx, self.geom_yy, self.geom_xy
    # either explicitly via cost factorization or implicitly (e.g., a PC)
    return self.ranks != -1 or (
        convertible(geom_xx) and convertible(geom_yy) and
        (geom_xy is None or convertible(geom_xy))
    )

  def to_low_rank(
      self,
      rng: Optional[jax.Array] = None,
  ) -> "QuadraticProblem":
    """Convert geometries to low-rank.

    Args:
      rng: Random key for seeding.

    Returns:
      Quadratic problem with low-rank geometries; ``self`` if the problem is
      already low-rank.
    """

    def convert(
        vals: Union[int, float, Tuple[Union[int, float], ...]]
    ) -> Tuple[Union[int, float], ...]:
      # Normalize a scalar or per-geometry sequence to exactly three entries;
      # the missing ``geom_xy`` slot is padded with ``None``.
      size = 2 + self.is_fused
      if isinstance(vals, (int, float)):
        return (vals,) * 3
      assert len(vals) == size, vals
      return tuple(vals) + (None,) * (3 - size)

    if self.is_low_rank:
      return self

    rng = utils.default_prng_key(rng)
    rng1, rng2, rng3 = jax.random.split(rng, 3)
    (geom_xx, geom_yy, geom_xy, *children), aux_data = self.tree_flatten()
    (r1, r2, r3), (t1, t2, t3) = convert(self.ranks), convert(self.tolerances)

    geom_xx = geom_xx.to_LRCGeometry(rank=r1, tol=t1, rng=rng1)
    geom_yy = geom_yy.to_LRCGeometry(rank=r2, tol=t2, rng=rng2)
    if self.is_fused:
      if isinstance(
          geom_xy, pointcloud.PointCloud
      ) and geom_xy.is_squared_euclidean:
        geom_xy = geom_xy.to_LRCGeometry(scale=self.fused_penalty)
      else:
        geom_xy = geom_xy.to_LRCGeometry(
            rank=r3, tol=t3, rng=rng3, scale=self.fused_penalty
        )

    return type(self).tree_unflatten(
        aux_data, [geom_xx, geom_yy, geom_xy] + children
    )

  @property
  def geom_xx(self) -> geometry.Geometry:
    """Geometry of the first space."""
    return self._geom_xx

  @property
  def geom_yy(self) -> geometry.Geometry:
    """Geometry of the second space."""
    return self._geom_yy

  @property
  def geom_xy(self) -> Optional[geometry.Geometry]:
    """Geometry of the joint space."""
    return self._geom_xy

  @property
  def a(self) -> jnp.ndarray:
    """First marginal."""
    num_a = self.geom_xx.shape[0]
    return jnp.ones((num_a,)) / num_a if self._a is None else self._a

  @property
  def b(self) -> jnp.ndarray:
    """Second marginal."""
    num_b = self.geom_yy.shape[0]
    return jnp.ones((num_b,)) / num_b if self._b is None else self._b

  @property
  def is_fused(self) -> bool:
    """Whether the problem is fused."""
    return self.geom_xy is not None

  @property
  def is_low_rank(self) -> bool:
    """Whether all geometries are low-rank."""
    return (
        isinstance(self.geom_xx, low_rank.LRCGeometry) and
        isinstance(self.geom_yy, low_rank.LRCGeometry) and
        (not self.is_fused or isinstance(self.geom_xy, low_rank.LRCGeometry))
    )

  @property
  def linear_loss(self):
    """Linear part of the Gromov-Wasserstein loss."""
    return self.loss.f1, self.loss.f2

  @property
  def quad_loss(self):
    """Quadratic part of the Gromov-Wasserstein loss."""
    return self.loss.h1, self.loss.h2

  @property
  def is_balanced(self) -> bool:
    """Whether the problem is balanced."""
    return ((not self.gw_unbalanced_correction) or
            (self.tau_a == 1.0 and self.tau_b == 1.0))

  def tree_flatten(self):  # noqa: D102
    # Children: the three geometries and the two (possibly ``None``) marginals.
    # Everything else travels as auxiliary data, keyed by ``__init__`` name.
    return ([self.geom_xx, self.geom_yy, self.geom_xy, self._a, self._b], {
        "tau_a": self.tau_a,
        "tau_b": self.tau_b,
        "loss": self._loss_name,
        "fused_penalty": self.fused_penalty,
        "scale_cost": self.scale_cost,
        "gw_unbalanced_correction": self.gw_unbalanced_correction,
        "ranks": self.ranks,
        "tolerances": self.tolerances
    })

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    geoms, (a, b) = children[:3], children[3:]
    return cls(*geoms, a=a, b=b, **aux_data)

def update_epsilon_unbalanced(
    epsilon: Union[float, epsilon_scheduler.Epsilon], transport_mass: float
) -> epsilon_scheduler.Epsilon:
  """Scale the epsilon of an unbalanced problem by the transport mass.

  Args:
    epsilon: Epsilon scheduler, or a plain value that is first wrapped into
      an :class:`~ott.geometry.epsilon_scheduler.Epsilon` with
      ``scale_epsilon=1.0``.
    transport_mass: Factor by which the current ``scale_epsilon`` is
      multiplied.

  Returns:
    The scheduler returned by ``epsilon.set`` with the rescaled
    ``scale_epsilon``.
  """
  if not isinstance(epsilon, epsilon_scheduler.Epsilon):
    epsilon = epsilon_scheduler.Epsilon(epsilon, scale_epsilon=1.0)
  return epsilon.set(scale_epsilon=epsilon._scale_epsilon * transport_mass)

def apply_cost(  # noqa: D103
geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int,