#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
Target Discriminant Analysis Loss Function.
"""
__all__ = ["TDALoss", "tda_loss"]
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
from typing import Union, List, Tuple
from warnings import warn
import torch
# =============================================================================
# LOSS FUNCTIONS
# =============================================================================
class TDALoss(torch.nn.Module):
    """Compute a loss function as the distance from a simple Gaussian target distribution.

    The module stores the target centers and standard deviations as buffers
    and delegates the actual computation to :func:`tda_loss`.
    """

    def __init__(
        self,
        n_states: int,
        target_centers: Union[List[float], torch.Tensor],
        target_sigmas: Union[List[float], torch.Tensor],
        alpha: float = 1.0,
        beta: float = 100.0,
    ):
        """Constructor.

        Parameters
        ----------
        n_states : int
            Number of states. The integer labels are expected to be in between 0
            and ``n_states-1``.
        target_centers : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
        target_sigmas : list or torch.Tensor
            Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian targets.
        alpha : float, optional
            Centers_loss component prefactor, by default 1.
        beta : float, optional
            Sigmas loss component prefactor, by default 100.
        """
        super().__init__()
        self.n_states = n_states

        # Use torch.tensor rather than the legacy torch.Tensor constructor: the
        # latter interprets an integer argument as a *size*, not as a value.
        # The default floating dtype is requested explicitly so that integer
        # lists such as [-1, 1] still become floating-point targets.
        if not isinstance(target_centers, torch.Tensor):
            target_centers = torch.tensor(
                target_centers, dtype=torch.get_default_dtype()
            )
        if not isinstance(target_sigmas, torch.Tensor):
            target_sigmas = torch.tensor(
                target_sigmas, dtype=torch.get_default_dtype()
            )

        # Registered as buffers (not parameters): they are not trained, but they
        # are moved together with the module and stored in its state_dict.
        self.register_buffer("target_centers", target_centers)
        self.register_buffer("target_sigmas", target_sigmas)

        self.alpha = alpha
        self.beta = beta

    def forward(
        self, H: torch.Tensor, labels: torch.Tensor, return_loss_terms: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Compute the value of the loss function.

        Parameters
        ----------
        H : torch.Tensor
            Shape ``(n_batches, n_features)``. Output of the NN.
        labels : torch.Tensor
            Shape ``(n_batches,)``. Labels of the dataset.
        return_loss_terms : bool, optional
            If ``True``, the loss terms associated to the center and standard
            deviations of the target Gaussians are returned as well. Default
            is ``False``.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        loss_centers : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the
            loss term associated to the centers of the target Gaussians.
        loss_sigmas : torch.Tensor, optional
            Only returned if ``return_loss_terms is True``. The value of the
            loss term associated to the standard deviations of the target Gaussians.
        """
        return tda_loss(
            H,
            labels,
            self.n_states,
            self.target_centers,
            self.target_sigmas,
            self.alpha,
            self.beta,
            return_loss_terms,
        )
def tda_loss(
    H: torch.Tensor,
    labels: torch.Tensor,
    n_states: int,
    target_centers: Union[List[float], torch.Tensor],
    target_sigmas: Union[List[float], torch.Tensor],
    alpha: float = 1,
    beta: float = 100,
    return_loss_terms: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Compute a loss function as the distance from a simple Gaussian target distribution.

    For each state ``i`` the mean and standard deviation of ``H`` over the
    samples labelled ``i`` are computed, and their squared deviations from
    ``target_centers[i]`` and ``target_sigmas[i]`` are summed, weighted by
    ``alpha`` and ``beta`` respectively.

    Parameters
    ----------
    H : torch.Tensor
        Shape ``(n_batches, n_cvs)``. Output of the NN.
    labels : torch.Tensor
        Shape ``(n_batches,)``. Labels of the dataset.
    n_states : int
        The integer labels are expected to be in between 0 and ``n_states-1``.
    target_centers : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Centers of the Gaussian targets.
    target_sigmas : list or torch.Tensor
        Shape ``(n_states, n_cvs)``. Standard deviations of the Gaussian targets.
    alpha : float, optional
        Centers_loss component prefactor, by default 1.
    beta : float, optional
        Sigmas loss component prefactor, by default 100.
    return_loss_terms : bool, optional
        If ``True``, the loss terms associated to the center and standard deviations
        of the target Gaussians are returned as well. Default is ``False``.

    Returns
    -------
    loss : torch.Tensor
        Loss value.
    loss_centers : torch.Tensor, optional
        Only returned if ``return_loss_terms is True``. The value of the loss
        term associated to the centers of the target Gaussians.
    loss_sigmas : torch.Tensor, optional
        Only returned if ``return_loss_terms is True``. The value of the loss
        term associated to the standard deviations of the target Gaussians.

    Raises
    ------
    ValueError
        If one of the ``n_states`` states has no sample in the batch.

    Warns
    -----
    UserWarning
        If a state has exactly one sample in the batch; its standard
        deviation is then taken to be 0.
    """
    if not isinstance(target_centers, torch.Tensor):
        target_centers = torch.tensor(target_centers, dtype=H.dtype)
    if not isinstance(target_sigmas, torch.Tensor):
        target_sigmas = torch.tensor(target_sigmas, dtype=H.dtype)

    device = H.device
    target_centers = target_centers.to(device)
    target_sigmas = target_sigmas.to(device)

    # Allocate the per-state accumulators with the dtype that results from
    # combining H with the targets. Using the targets' own dtype would silently
    # truncate the loss (and break autograd) when the targets are given as an
    # integer tensor.
    loss_centers = torch.zeros_like(
        target_centers, dtype=torch.result_type(H, target_centers)
    )
    loss_sigmas = torch.zeros_like(
        target_sigmas, dtype=torch.result_type(H, target_sigmas)
    )

    for i in range(n_states):
        # check which elements belong to class i (mask computed only once)
        mask = labels == i
        n_samples = int(mask.sum())
        if n_samples == 0:
            raise ValueError(
                f"State {i} was not represented in this batch! Either use bigger batch_size or a more equilibrated dataset composition!"
            )
        H_red = H[mask]

        # compute mean and standard deviation over the class i
        mu = torch.mean(H_red, 0)
        if n_samples == 1:
            warn(
                f"There is only one sample for state {i} in this batch! Std is set to 0, this may affect the training! Either use bigger batch_size or a more equilibrated dataset composition!"
            )
            # The sample standard deviation of a single point is undefined, so
            # use an actual zero with the same shape/dtype/device as mu.
            # (torch.Tensor(0) would be an *empty* tensor, not a zero.)
            sigma = torch.zeros_like(mu)
        else:
            sigma = torch.std(H_red, 0)

        # compute loss function contributes for class i
        loss_centers[i] = alpha * (mu - target_centers[i]).pow(2)
        loss_sigmas[i] = beta * (sigma - target_sigmas[i]).pow(2)

    # get total model loss
    loss_centers = torch.sum(loss_centers)
    loss_sigmas = torch.sum(loss_sigmas)
    loss = loss_centers + loss_sigmas

    if return_loss_terms:
        return loss, loss_centers, loss_sigmas
    return loss
def test_tda_loss():
    """Check ``TDALoss`` on a balanced two-state batch of scalar NN outputs.

    Verifies that the loss is a finite scalar, that it equals the sum of its
    center and sigma terms, and that gradients propagate back to the input.
    """
    # A local generator makes the input reproducible without touching the
    # global RNG state used by other tests.
    rng = torch.Generator().manual_seed(0)
    H = torch.randn(100, generator=rng)
    H.requires_grad = True

    # first 50 samples belong to state 0, last 50 to state 1
    labels = torch.zeros(100, dtype=torch.long)
    labels[-50:] = 1

    Loss = TDALoss(n_states=2, target_centers=[-1, 1], target_sigmas=[0.1, 0.1])
    loss, loss_centers, loss_sigmas = Loss(H=H, labels=labels, return_loss_terms=True)

    assert loss.ndim == 0
    assert torch.isfinite(loss)
    assert torch.allclose(loss, loss_centers + loss_sigmas)

    loss.backward()
    assert H.grad is not None
    assert H.grad.shape == H.shape
    assert torch.isfinite(H.grad).all()