mlcolvar.core.loss.ELBOGaussiansLoss
- class mlcolvar.core.loss.ELBOGaussiansLoss(*args: Any, **kwargs: Any)
Bases: Module
ELBO loss function assuming the latent and reconstruction distributions are Gaussian.
The ELBO uses the MSE as the reconstruction loss (i.e., it assumes that the decoder outputs the mean of a Gaussian distribution with variance 1) and the KL divergence between the two normal distributions N(mean, var) and N(0, 1), where mean and var are the outputs of the encoder (var is passed to forward as its logarithm, log_variance).
- __init__(*args: Any, **kwargs: Any) → None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
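For a single sample, the loss described above can be written as follows. This uses the standard closed form of the KL divergence between a diagonal Gaussian and N(0, 1). Here x is target, x̂ is output, μ is mean, log σ² is log_variance, D and d are in_features and latent_features, and β is the beta argument of forward. Averaging the squared error over the D input features is an assumption about the reduction, not something stated on this page; the (optionally weighted) average over the batch is then taken on top of this per-sample value.

```latex
\mathcal{L}
  = \underbrace{\frac{1}{D}\sum_{i=1}^{D}\left(\hat{x}_i - x_i\right)^2}_{\text{reconstruction (MSE)}}
  + \beta \, \underbrace{\left[-\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)\right]}_{\mathrm{KL}\left(\mathcal{N}(\mu,\,\sigma^2)\,\|\,\mathcal{N}(0,\,1)\right)}
```

With μ = 0 and log σ² = 0 every KL summand vanishes, so the loss reduces to the reconstruction term alone.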
Methods
forward(target, output, mean, log_variance) – Compute the value of the loss function.
- forward(target: Tensor, output: Tensor, mean: Tensor, log_variance: Tensor, beta: float = 1.0, return_loss_terms: bool = False, weights: Tensor | None = None) → Tensor
Compute the value of the loss function.
- Parameters:
target (torch.Tensor) – Shape (n_batches, in_features). Data points (e.g., the input of the encoder or time-lagged features).
output (torch.Tensor) – Shape (n_batches, in_features). Output of the decoder.
mean (torch.Tensor) – Shape (n_batches, latent_features). The means of the Gaussian distributions associated with the inputs.
log_variance (torch.Tensor) – Shape (n_batches, latent_features). The logarithms of the variances of the Gaussian distributions associated with the inputs.
beta (float, optional) – Scaling factor for the KL divergence term. The default is 1.0, which leaves the KL divergence unscaled. Values greater than 1 increase the weight of the KL term in the loss (useful to strengthen regularization); values less than 1 decrease it (useful to avoid posterior collapse).
return_loss_terms (bool, optional) – If True, return the two main terms of the ELBO (reconstruction loss and KL divergence) separately, in addition to the total loss. The default is False, which returns only the total loss.
weights (torch.Tensor, optional) – Shape (n_batches,) or (n_batches, 1). If given, the average over the batch is weighted. The default (None) is unweighted.
- Returns:
loss – The value of the loss function.
- Return type:
torch.Tensor
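To make the role of beta, weights, and return_loss_terms concrete, here is a minimal PyTorch sketch of the computation described above. It is not the mlcolvar source: the function name elbo_gaussians_loss_sketch is hypothetical, and the reductions (MSE averaged over input features, KL summed over latent features, weights normalized to sum to one) and the order of the returned terms are assumptions.

```python
from typing import Optional

import torch


def elbo_gaussians_loss_sketch(
    target: torch.Tensor,
    output: torch.Tensor,
    mean: torch.Tensor,
    log_variance: torch.Tensor,
    beta: float = 1.0,
    return_loss_terms: bool = False,
    weights: Optional[torch.Tensor] = None,
):
    """Hypothetical re-implementation of the negative ELBO for Gaussian latents."""
    # Per-sample KL(N(mean, var) || N(0, 1)), summed over latent features.
    kl = -0.5 * torch.sum(1.0 + log_variance - mean**2 - log_variance.exp(), dim=1)
    # Per-sample squared error, averaged over input features (assumed reduction).
    reconstruction = ((output - target) ** 2).mean(dim=1)

    if weights is None:
        # Plain average over the batch.
        kl = kl.mean()
        reconstruction = reconstruction.mean()
    else:
        # Accept (n_batches,) or (n_batches, 1); normalizing to sum 1 is an assumption.
        w = weights.reshape(-1)
        w = w / w.sum()
        kl = (kl * w).sum()
        reconstruction = (reconstruction * w).sum()

    loss = reconstruction + beta * kl
    if return_loss_terms:
        # Assumed order: total loss, reconstruction term, KL term.
        return loss, reconstruction, kl
    return loss


# Example: a random batch with the documented shapes.
torch.manual_seed(0)
target = torch.randn(8, 5)        # (n_batches, in_features)
output = torch.randn(8, 5)        # (n_batches, in_features)
mean = torch.randn(8, 2)          # (n_batches, latent_features)
log_variance = torch.randn(8, 2)  # (n_batches, latent_features)
loss, reconstruction, kl = elbo_gaussians_loss_sketch(
    target, output, mean, log_variance, beta=0.5, return_loss_terms=True
)
```

With mlcolvar installed, the documented class is called with the same arguments, e.g. ELBOGaussiansLoss()(target, output, mean, log_variance, beta=0.5); the sketch is only meant to show how each argument enters the result.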
Attributes
T_destination
call_super_init
dump_patches
training