mlcolvar.core.loss.ELBOGaussiansLoss

class mlcolvar.core.loss.ELBOGaussiansLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

ELBO loss function assuming the latent and reconstruction distributions are Gaussian.

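The loss is the sum of two terms. The reconstruction term is the MSE between the target and the decoder output, which amounts to assuming that the decoder outputs the mean of a Gaussian distribution with variance 1. The regularization term is the KL divergence between the normal distributions N(mean, var) and N(0, 1), where mean and var are produced by the encoder (the variance is passed as its logarithm, see log_variance below).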
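Written out, the description above corresponds to an expression of the following form. This is a sketch rather than the library's exact definition: the symbols (x_n for the target, \hat{x}_n for the decoder output, \mu and \sigma^2 for the encoder mean and variance, N for the batch size, d for the number of latent features) are introduced here for illustration, and the reductions (sum over features, average over the batch) are assumptions, since the text does not state the normalization. Only the closed form of the KL divergence between N(\mu, \sigma^2) and N(0, 1) is a standard result.

```latex
\mathcal{L}
  = \frac{1}{N} \sum_{n=1}^{N}
    \left[
      \lVert x_n - \hat{x}_n \rVert^2
      + \beta \, D_{\mathrm{KL}}\!\left( \mathcal{N}(\mu_n, \sigma_n^2) \,\Vert\, \mathcal{N}(0, 1) \right)
    \right],
\qquad
D_{\mathrm{KL}}\!\left( \mathcal{N}(\mu_n, \sigma_n^2) \,\Vert\, \mathcal{N}(0, 1) \right)
  = -\frac{1}{2} \sum_{j=1}^{d}
    \left( 1 + \log \sigma_{n,j}^2 - \mu_{n,j}^2 - \sigma_{n,j}^2 \right)
```

Here \beta is the scaling factor of the KL term described under the parameters of forward; with \beta = 1 the KL term is unscaled.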

__init__(*args: Any, **kwargs: Any) None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Methods

forward(target, output, mean, log_variance)

Compute the value of the loss function.

forward(target: Tensor, output: Tensor, mean: Tensor, log_variance: Tensor, beta: float = 1.0, return_loss_terms: bool = False, weights: Tensor | None = None) Tensor[source]

Compute the value of the loss function.

Parameters:
  • target (torch.Tensor) – Shape (n_batches, in_features). Data points (e.g. input of encoder or time-lagged features).

  • output (torch.Tensor) – Shape (n_batches, in_features). Output of the decoder.

  • mean (torch.Tensor) – Shape (n_batches, latent_features). The means of the Gaussian distributions associated with the inputs.

  • log_variance (torch.Tensor) – Shape (n_batches, latent_features). The logarithm of the variances of the Gaussian distributions associated with the inputs.

  • beta (float, optional) – A scaling factor for the KL divergence term. The default is 1.0, which leaves the KL divergence unscaled. A value greater than 1 increases the weight of the KL divergence term in the loss function (useful to strengthen regularization). A value less than 1 decreases the weight of the KL divergence term (useful to avoid posterior collapse).

  • return_loss_terms (bool, optional) – If True, return the two main terms of the ELBO (reconstruction loss and KL divergence) separately, in addition to the total loss. The default is False, which returns just the total loss.

  • weights (torch.Tensor, optional) – Shape (n_batches,) or (n_batches, 1). If given, the average over the batch is weighted by these values. The default (None) gives an unweighted average.

Returns:

loss – The value of the loss function. If return_loss_terms is True, the reconstruction loss and the KL divergence are returned as well.

Return type:

torch.Tensor
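To make the role of each argument concrete, here is a minimal PyTorch sketch of a loss with the same arguments as forward. It is not the mlcolvar implementation: the function name elbo_gaussians_sketch is hypothetical, and the reductions (sum over features, average over the batch), the normalization of weights by their sum, and the order of the returned terms are all assumptions made for illustration.

```python
import torch


def elbo_gaussians_sketch(target, output, mean, log_variance,
                          beta=1.0, return_loss_terms=False, weights=None):
    """Illustrative ELBO-style loss (hypothetical helper, not the library code)."""
    # Reconstruction term: squared error summed over features (assumed reduction).
    reconstruction = (output - target).pow(2).sum(dim=1)

    # Closed-form KL( N(mean, var) || N(0, 1) ), summed over latent features.
    kl = -0.5 * (1.0 + log_variance - mean.pow(2) - log_variance.exp()).sum(dim=1)

    if weights is None:
        # Unweighted average over the batch.
        reconstruction = reconstruction.mean()
        kl = kl.mean()
    else:
        # Accept shape (n_batches,) or (n_batches, 1).
        # Normalizing by the sum of the weights is an assumption of this sketch.
        w = weights.reshape(-1)
        w = w / w.sum()
        reconstruction = (reconstruction * w).sum()
        kl = (kl * w).sum()

    # beta rescales only the KL term.
    loss = reconstruction + beta * kl
    if return_loss_terms:
        # Return order is an assumption of this sketch.
        return loss, reconstruction, kl
    return loss


# Toy data: 8 samples, 5 input features, 2 latent features.
torch.manual_seed(0)
target = torch.randn(8, 5)
output = target + 0.1 * torch.randn(8, 5)
mean = torch.randn(8, 2)
log_variance = torch.zeros(8, 2)  # unit variance

loss, rec, kl = elbo_gaussians_sketch(
    target, output, mean, log_variance, beta=0.5, return_loss_terms=True
)

# Uniform weights of shape (n_batches, 1) reproduce the unweighted average here.
loss_uniform = elbo_gaussians_sketch(
    target, output, mean, log_variance, beta=0.5, weights=torch.ones(8, 1)
)
```

With unit variance (log_variance equal to zero) the KL term reduces to half the squared norm of mean, which gives a quick sanity check when experimenting with beta.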

Attributes

T_destination

call_super_init

dump_patches

training