ott.neural.networks.icnn.KeyNet.gradient

KeyNet.gradient(x)

Compute the vector-valued output of the network (the predicted gradient, also called the key).

Parameters:

x (Array) – Input of shape [batch_size, input_dim].

Return type:

Array

Returns:

Output of shape [batch_size, output_dim] or [batch_size, num_outputs, output_dim].
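The two documented return shapes can be illustrated with a small stand-in. This is a hypothetical sketch, not KeyNet's implementation. The functions `init_params` and `hypothetical_gradient`, the single linear head, and the `num_outputs` argument are assumptions. They exist only to show how an input of shape [batch_size, input_dim] maps to either [batch_size, output_dim] or [batch_size, num_outputs, output_dim].

```python
import jax
import jax.numpy as jnp


def init_params(key, input_dim, output_dim, num_outputs=None):
    # Hypothetical: one linear layer whose width covers every output head.
    k = output_dim if num_outputs is None else num_outputs * output_dim
    return {
        "w": jax.random.normal(key, (input_dim, k)) / jnp.sqrt(input_dim),
        "b": jnp.zeros((k,)),
    }


def hypothetical_gradient(params, x, num_outputs=None):
    """Stand-in mirroring the documented shape contract (assumption).

    x: [batch_size, input_dim]
    returns [batch_size, output_dim] when num_outputs is None,
    otherwise [batch_size, num_outputs, output_dim].
    """
    out = x @ params["w"] + params["b"]  # [batch_size, k]
    if num_outputs is None:
        return out
    # Split the flat output into num_outputs vectors of size output_dim.
    return jnp.reshape(out, (x.shape[0], num_outputs, -1))


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 3))  # batch_size=8, input_dim=3

p1 = init_params(key, input_dim=3, output_dim=2)
y1 = hypothetical_gradient(p1, x)

p2 = init_params(key, input_dim=3, output_dim=2, num_outputs=4)
y2 = hypothetical_gradient(p2, x, num_outputs=4)

print(y1.shape)  # (8, 2)
print(y2.shape)  # (8, 4, 2)
```

In this sketch the multi-output case is just a reshape of one wide head. An actual network may produce its outputs differently, but the batch dimension always stays first.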