Source code for mlcolvar.core.loss.elbo

#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Evidence Lower BOund (ELBO) loss functions used to train variational Autoencoders.
"""

__all__ = ["ELBOGaussiansLoss", "elbo_gaussians_loss"]


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Optional
import torch
from mlcolvar.core.loss.mse import mse_loss


# =============================================================================
# LOSS FUNCTIONS
# =============================================================================


class ELBOGaussiansLoss(torch.nn.Module):
    """Module wrapper around :func:`elbo_gaussians_loss`.

    The loss is an ELBO in which both the latent and the reconstruction
    distributions are taken to be Gaussian. The reconstruction term is the
    MSE (i.e., the decoder is assumed to output the mean of a Gaussian with
    variance 1), and the regularization term is the KL divergence between
    ``N(mean, var)`` and ``N(0, 1)``, where ``mean`` and ``var`` are the
    output of the encoder.
    """

    def forward(
        self,
        target: torch.Tensor,
        output: torch.Tensor,
        mean: torch.Tensor,
        log_variance: torch.Tensor,
        beta: float = 1.0,
        return_loss_terms: bool = False,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Evaluate the ELBO loss on a batch.

        All arguments are forwarded unchanged to :func:`elbo_gaussians_loss`.

        Parameters
        ----------
        target : torch.Tensor
            Shape ``(n_batches, in_features)``. Data points (e.g. input of
            encoder or time-lagged features).
        output : torch.Tensor
            Shape ``(n_batches, in_features)``. Output of the decoder.
        mean : torch.Tensor
            Shape ``(n_batches, latent_features)``. The means of the Gaussian
            distributions associated to the inputs.
        log_variance : torch.Tensor
            Shape ``(n_batches, latent_features)``. The logarithm of the
            variances of the Gaussian distributions associated to the inputs.
        beta : float, optional
            Scaling factor applied to the KL divergence term. The default
            (1.0) leaves the KL divergence unscaled. Values greater than 1
            increase the weight of the KL term (stronger regularization);
            values smaller than 1 decrease it (useful to avoid posterior
            collapse).
        return_loss_terms : bool, optional
            If ``True``, the reconstruction loss and the KL divergence are
            returned together with the total loss. The default is ``False``,
            which returns just the total loss.
        weights : torch.Tensor, optional
            Shape ``(n_batches,)`` or ``(n_batches,1)``. Per-sample weights.
            The default (``None``) is unweighted.

        Returns
        -------
        loss : torch.Tensor
            The value of the loss function. When ``return_loss_terms`` is
            ``True`` the return value is instead the tuple
            ``(loss, reconstruction, kl)``.
        """
        # Keyword forwarding keeps the mapping explicit; the functional API
        # uses the very same parameter names.
        return elbo_gaussians_loss(
            target=target,
            output=output,
            mean=mean,
            log_variance=log_variance,
            beta=beta,
            return_loss_terms=return_loss_terms,
            weights=weights,
        )
def elbo_gaussians_loss(
    target: torch.Tensor,
    output: torch.Tensor,
    mean: torch.Tensor,
    log_variance: torch.Tensor,
    beta: float = 1.0,
    return_loss_terms: bool = False,
    weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """ELBO loss function assuming the latent and reconstruction distributions are Gaussian.

    The ELBO uses the MSE as the reconstruction loss (i.e., assumes that the
    decoder outputs the mean of a Gaussian distribution with variance 1), and
    the KL divergence between two normal distributions ``N(mean, var)`` and
    ``N(0, 1)``, where ``mean`` and ``var`` are the output of the encoder.

    Parameters
    ----------
    target : torch.Tensor
        Shape ``(n_batches, in_features)``. Data points (e.g. input of encoder
        or time-lagged features).
    output : torch.Tensor
        Shape ``(n_batches, in_features)``. Output of the decoder.
    mean : torch.Tensor
        Shape ``(n_batches, latent_features)``. The means of the Gaussian
        distributions associated to the inputs.
    log_variance : torch.Tensor
        Shape ``(n_batches, latent_features)``. The logarithm of the variances
        of the Gaussian distributions associated to the inputs.
    beta : float, optional
        A scaling factor for the KL divergence term. The default is 1.0, which
        means that the KL divergence is not scaled. If set to a value greater
        than 1, it will increase the weight of the KL divergence term in the
        loss function (useful to increase regularization). If set to a value
        less than 1, it will decrease the weight of the KL divergence term
        (useful to avoid posterior collapse).
    return_loss_terms : bool, optional
        If ``True``, besides the total loss, return the two main terms of the
        ELBO separately (reconstruction loss and KL divergence). The default
        is ``False``, which returns just the total loss.
    weights : torch.Tensor, optional
        Shape ``(n_batches,)`` or ``(n_batches,1)``. Per-sample weights. If
        given, the KL term is computed as ``sum(kl * weights)`` over the batch
        (the weights are not normalized here, so they must sum to 1 for this
        to be a weighted average), and the same weights are passed on to
        ``mse_loss`` for the reconstruction term. The default (``None``)
        takes the plain mean over the batch.

    Returns
    -------
    loss : torch.Tensor
        The value of the loss function, ``reconstruction + beta * kl``. When
        ``return_loss_terms`` is ``True`` the return value is instead the
        tuple ``(loss, reconstruction, kl)``.

    Raises
    ------
    ValueError
        If ``weights``, once its singleton dimensions are removed, does not
        have shape ``(n_batches,)``.
    """
    # KL divergence between N(mean, variance) and N(0, 1), summed over the
    # latent dimensions so that there is one value per sample.
    # See https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
    kl = -0.5 * (log_variance - log_variance.exp() - mean**2 + 1).sum(dim=1)

    # Reduce over batches.
    if weights is None:
        kl = kl.mean()
    else:
        weights = weights.squeeze()
        # squeeze() removes *every* size-1 dimension, so for a single-sample
        # batch it also removes the batch dimension and yields a 0-d tensor.
        # Restore the batch dimension so that valid (1,) / (1, 1) weights are
        # not rejected by the shape check below.
        if weights.ndim == 0:
            weights = weights.reshape(1)
        if weights.shape != kl.shape:
            raise ValueError(
                f"weights should be a tensor of shape (n_batches,) or (n_batches,1), not {weights.shape}."
            )
        kl = (kl * weights).sum()

    # Reconstruction loss (receives the 1-D weights, or None).
    reconstruction = mse_loss(output, target, weights=weights)

    loss = reconstruction + beta * kl
    if return_loss_terms:
        return loss, reconstruction, kl
    return loss