Source code for mlcolvar.cvs.unsupervised.vae

#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Variational Autoencoder collective variable.
"""

__all__ = ["VariationalAutoEncoderCV"]

# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Any, Optional, Tuple
import torch
from mlcolvar.cvs import BaseCV
from mlcolvar.core import FeedForward, Normalization
from mlcolvar.core.transform.utils import Inverse
from mlcolvar.core.loss import ELBOGaussiansLoss


# =============================================================================
# VARIATIONAL AUTOENCODER CV
# =============================================================================


[docs] class VariationalAutoEncoderCV(BaseCV): """Variational AutoEncoder Collective Variable. At training time, the encoder outputs a mean and a variance for each CV defining a Gaussian distribution associated to the input. One sample is drawn from this Gaussian, and it goes through the decoder. Then the ELBO loss is minimized [1]_. The ELBO sums the MSE of the reconstruction and the KL divergence between the generated Gaussian and a N(0, 1) Gaussian multiplied by a beta factor. The beta factor can be used to control the weight of the KL divergence term in the loss function [2]_. At evaluation time, the encoder's output mean is used as the CV, while the variance output and the decoder are ignored. **Data**: for training, it requires a DictDataset with the key ``'data'`` and optionally ``'weights'``. If a 'target' key is present this will be used as reference for the output of the decoder, otherway this will be compared with the input 'data'. This feature can be used to train (variational) time-lagged autoencoders like in [3]_. **Loss**: Evidence Lower BOund (ELBO) References ---------- .. [1] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." 20 Dec. 2013. .. [2] Higgins, Irina, Loic Matthey, Arka Pal, Christopher Burgess, Xavier Glorot, Matthew Botvinick, Shakir Mohamed, and Alexander Lerchner. “β-VAE: Learning basic visual concepts with a constrained variational framework.” International Conference on Learning Representations (ICLR), 2017. .. [3] C. X. Hernández, H. K. Wayment-Steele, M. M. Sultan, B. E. Husic, and V. S. Pande, “Variational encoding of complex dynamics,” Physical Review E 97, 062412 (2018). See also -------- mlcolvar.core.loss.ELBOLoss Evidence Lower BOund loss function """ DEFAULT_BLOCKS = ["norm_in", "encoder", "decoder"]
[docs] def __init__( self, n_cvs: int, encoder_layers: list, decoder_layers: Optional[list] = None, beta: float = 1.0, options: Optional[dict] = None, **kwargs, ): """ Variational autoencoder constructor. Initializes two neural network modules (encoder and decoder). By default a module standardizing the inputs is also used. Parameters ---------- n_cvs : int The dimension of the CV or, equivalently, the dimension of the latent space of the autoencoder. encoder_layers : list Number of neurons per layer of the encoder up to the last hidden layer. The size of the output layer is instead specified with ``n_cvs`` decoder_layers : list, optional Number of neurons per layer of the decoder, except for the input layer which is specified by ``n_cvs``. If ``None`` (default), it takes automatically the reversed architecture of the encoder. beta : float, optional A scaling factor for the KL divergence term in the loss function. The default is 1.0. options : dict[str, Any], optional Options for the building blocks of the model, by default ``None``. Available blocks are: ``'norm_in'``, ``'encoder'``, and ``'decoder'``. Set ``'block_name' = None`` or ``False`` to turn off a block. Encoder and decoder cannot be turned off. """ if not isinstance(encoder_layers, list): raise NotImplementedError( f"Encoder layer must be a list. found {type(encoder_layers)}" ) super().__init__(model=encoder_layers, **kwargs) # this makes checkpointing safe, to avoid double model keys self.save_hyperparameters(ignore=['model']) self.hparams.pop('model',None) # here we need to override the self.out_features attribute self.out_features = n_cvs # ======= LOSS ======= # ELBO loss function when latent space and reconstruction distributions are Gaussians. self.loss_fn = ELBOGaussiansLoss() self.beta = beta # ======= OPTIONS ======= # parse and sanitize options = self.parse_options(options) # if decoder is not given reverse the encoder if decoder_layers is None: decoder_layers = encoder_layers[::-1] # ======= BLOCKS ======= # initialize norm_in o = "norm_in" if (options[o] is not False) and (options[o] is not None): self.norm_in = Normalization(self.in_features, **options[o]) # initialize encoder # The encoder outputs two values for each CV representing mean and std. o = "encoder" # Note: The FeedForward implementing the encoder by default needs to have also # the nonlinearity (and eventually dropout/batchnorm) also for the output # layer since we'll have two separate linear layers for the mean and variance. if "last_layer_activation" not in options[o]: options[o]["last_layer_activation"] = True self.encoder = FeedForward(encoder_layers, **options[o]) self.mean_nn = torch.nn.Linear( in_features=encoder_layers[-1], out_features=n_cvs ) self.log_var_nn = torch.nn.Linear( in_features=encoder_layers[-1], out_features=n_cvs ) # initialize encoder o = "decoder" self.decoder = FeedForward([n_cvs] + decoder_layers, **options[o])
[docs] def forward_cv(self, x: torch.Tensor) -> torch.Tensor: """Compute the value of the CV from preprocessed input. Return the mean output (ignoring the variance output) of the encoder after (optionally) applying the normalization to the input. Parameters ---------- x : torch.Tensor Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The input descriptors of the CV after preprocessing. Returns ------- cv : torch.Tensor Shape ``(n_batches, n_cvs)``. The CVs, i.e., the mean output of the encoder (the variance output is discarded). """ if self.norm_in is not None: x = self._apply_module(self.norm_in, x) x = self._apply_module(self.encoder, x) # Take only the means and ignore the log variances. return self.mean_nn(x)
[docs] def encode_decode( self, x: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Run a pass of encoding + decoding. The function applies the normalizing to the inputs and its reverse on the output. Parameters ---------- x : torch.Tensor Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The input descriptors of the CV after preprocessing. Returns ------- mean : torch.Tensor Shape ``(n_batches, n_cvs)`` of ``(n_cvs,)``. The mean of the Gaussian distribution associated to the input in latent space. log_variance : torch.Tensor Shape ``(n_batches, n_cvs)`` of ``(n_cvs,)``. The logarithm of the variance of the Gaussian distribution associated to the input in latent space. x_hat : torch.Tensor Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The reconstructed descriptors. """ # Normalize inputs. if self.norm_in is not None: x = self._apply_module(self.norm_in, x) # Encode input into a Gaussian distribution. x = self._apply_module(self.encoder, x) mean, log_variance = self.mean_nn(x), self.log_var_nn(x) # Clamp the log_variance to prevent it from becoming -inf (numerical stability). log_variance = torch.clamp(log_variance, min=-15) # Sample from the Gaussian distribution in latent space. std = torch.exp(log_variance / 2) z = torch.distributions.Normal(mean, std).rsample() # Decode sample. x_hat = self._apply_module(self.decoder, z) if self.norm_in is not None: x_hat = self.norm_in.inverse(x_hat) return mean, log_variance, x_hat
[docs] def training_step(self, train_batch, batch_idx): """Single training step performed by the PyTorch Lightning Trainer.""" x = train_batch["data"] loss_kwargs = {} if "weights" in train_batch: loss_kwargs["weights"] = train_batch["weights"] # Encode/decode. mean, log_variance, x_hat = self.encode_decode(x) # Reference output (compare with a 'target' key if any, otherwise with input 'data') if "target" in train_batch: x_ref = train_batch["target"] else: x_ref = x # Loss function. loss, reconstruction_loss, kl_loss = self.loss_fn(target = x_ref, output = x_hat, mean = mean, log_variance = log_variance, beta = self.beta, return_loss_terms = True, **loss_kwargs) # Log. name = "train" if self.training else "valid" self.log(f"{name}_loss", loss, on_epoch=True) self.log(f"{name}_reconstruction_loss", reconstruction_loss, on_epoch=True) self.log(f"{name}_kl_loss", kl_loss, on_epoch=True) self.log(f"beta", self.beta, on_epoch=True) return loss
[docs] def get_decoder(self, return_normalization=False): """Return a torch model with the decoder and optionally the normalization inverse""" if return_normalization: if self.norm_in is not None: inv_norm = Inverse(module=self.norm_in) decoder_model = torch.nn.Sequential(*[self.decoder, inv_norm]) else: raise ValueError( "return_normalization is set to True but self.norm_in is None" ) else: decoder_model = self.decoder return decoder_model