Source code for mlcolvar.core.loss.fisher

#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Fisher discriminant loss for (Deep) Linear Discriminant Analysis.
"""

__all__ = ["FisherDiscriminantLoss", "fisher_discriminant_loss"]


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Optional

import torch

from mlcolvar.core.stats import LDA
from mlcolvar.core.loss import reduce_eigenvalues_loss


# =============================================================================
# LOSS FUNCTIONS
# =============================================================================


class FisherDiscriminantLoss(torch.nn.Module):
    """Fisher's discriminant ratio.

    Computes the sum (or another reducing functions) of the eigenvalues of the
    ratio between the Fisher's scatter matrices. This is the same loss function
    used in :class:`~mlcolvar.cvs.supervised.deeplda.DeepLDA`.

    This module only stores the configuration; the actual computation is
    delegated to :func:`fisher_discriminant_loss` on every forward pass.
    """

    def __init__(
        self,
        n_states: int,
        lda_mode: str = "standard",
        reduce_mode: str = "sum",
        lorentzian_reg: Optional[float] = None,
        invert_sign: bool = True,
        sw_reg: Optional[float] = 0.05,
    ):
        """Constructor.

        Parameters
        ----------
        n_states : int
            The number of states. Labels are in the range ``[0, n_states-1]``.
        lda_mode : str
            Either ``'standard'`` or ``'harmonic'``. This determines how the scatter
            matrices are computed (see also :class:`~mlcolvar.core.stats.lda.LDA`).
            The default is ``'standard'``.
        reduce_mode : str
            This determines how the eigenvalues are reduced, e.g., ``sum``, ``sum2``
            (see also :class:`~mlcolvar.core.loss.eigvals.ReduceEigenvaluesLoss`).
            The default is ``'sum'``.
        lorentzian_reg: float, optional
            The magnitude of the regularization for Lorentzian regularization.
            If not provided, this is automatically set according to ``sw_reg``.
        invert_sign: bool, optional
            Whether to return the negative Fisher's discriminant ratio in order to be
            minimized with gradient descent methods. Default is ``True``.
        sw_reg: float, optional
            The magnitude of the regularization for the within-scatter matrix,
            by default equal to 0.05. It is placed last in the signature so that
            existing positional calls keep their meaning.
        """
        super().__init__()
        self.n_states = n_states
        self.lda_mode = lda_mode
        self.reduce_mode = reduce_mode
        self.lorentzian_reg = lorentzian_reg
        self.invert_sign = invert_sign
        self.sw_reg = sw_reg

    def forward(
        self,
        x: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the value of the loss function.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_batches, n_features)``. Input features.
        labels : torch.Tensor
            Shape ``(n_batches,)``. Classes labels.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        """
        # Every stored option is forwarded by keyword so the module and the
        # functional form cannot drift apart on argument order.
        return fisher_discriminant_loss(
            x,
            labels,
            n_states=self.n_states,
            lda_mode=self.lda_mode,
            reduce_mode=self.reduce_mode,
            sw_reg=self.sw_reg,
            lorentzian_reg=self.lorentzian_reg,
            invert_sign=self.invert_sign,
        )
def fisher_discriminant_loss(
    x: torch.Tensor,
    labels: torch.Tensor,
    n_states: int,
    lda_mode: str = "standard",
    reduce_mode: str = "sum",
    sw_reg: Optional[float] = 0.05,
    lorentzian_reg: Optional[float] = None,
    invert_sign: bool = True,
) -> torch.Tensor:
    """Fisher's discriminant ratio.

    Computes the sum (or another reducing functions) of the eigenvalues of the
    ratio between the Fisher's scatter matrices with a Lorentzian regularization.
    This is the same loss function used in :class:`~mlcolvar.cvs.supervised.deeplda.DeepLDA`.

    The regularization term added to the reduced eigenvalues is
    ``-lorentzian_reg / (1 + (m - 1)**2)``, where ``m`` is the sum of the squared
    entries of ``x`` divided by ``x.size(0)``.

    Parameters
    ----------
    x : torch.Tensor
        Shape ``(n_batches, n_features)``. Input features.
    labels : torch.Tensor
        Shape ``(n_batches,)``. Classes labels.
    n_states : int
        The number of states. Labels are in the range ``[0, n_states-1]``.
    lda_mode : str, optional
        Either ``'standard'`` or ``'harmonic'``. This determines how the scatter
        matrices are computed (see also :class:`~mlcolvar.core.stats.lda.LDA`).
        The default is ``'standard'``.
    reduce_mode : str, optional
        This determines how the eigenvalues are reduced, e.g., ``sum``, ``sum2``
        (see also :class:`~mlcolvar.core.loss.eigvals.ReduceEigenvaluesLoss`).
        The default is ``'sum'``.
    sw_reg: float, optional
        The magnitude of the regularization for the within-scatter matrix,
        by default equal to 0.05.
    lorentzian_reg: float, optional
        The magnitude of the regularization for Lorentzian regularization.
        If not provided, this is automatically set to ``2.0 / sw_reg``.
    invert_sign: bool, optional
        Whether to return the negative Fisher's discriminant ratio in order to be
        minimized with gradient descent methods. Default is ``True``.

    Returns
    -------
    loss: torch.Tensor
        Loss value.

    Raises
    ------
    ValueError
        If ``lorentzian_reg`` is not provided and cannot be derived from
        ``sw_reg`` because the latter is ``None`` or ``0``.
    """
    # Resolve the Lorentzian magnitude up front: an unusable configuration is
    # reported before any scatter matrix or eigenvalue is computed.
    # The heuristic is the same used by DeepLDA.
    # TODO: ENCAPSULATE THIS IN A UTILITY FUNCTION USED BY BOTH THIS AND DEEPLDA?
    if lorentzian_reg is None:
        if sw_reg is None or sw_reg == 0:
            raise ValueError(
                f"Unable to calculate `lorentzian_reg` from `sw_reg` ({sw_reg}), please specify the value."
            )
        lorentzian_reg = 2.0 / sw_reg

    # define lda object
    lda = LDA(in_features=x.shape[-1], n_states=n_states, mode=lda_mode)

    # regularize s_w
    lda.sw_reg = sw_reg

    # compute LDA eigvals and reduce them to a scalar
    eigvals, _ = lda.compute(x, labels)
    loss = reduce_eigenvalues_loss(eigvals, mode=reduce_mode, invert_sign=invert_sign)

    # Lorentzian regularization: the term is most negative when the summed
    # squared entries of x, divided by x.size(0), equal 1.
    mean_sq_norm = x.pow(2).sum().div(x.size(0))
    reg_loss = -lorentzian_reg / (1 + (mean_sq_norm - 1).pow(2))

    return loss + reg_loss