Source code for mlcolvar.core.loss.mse
#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
(Weighted) Mean Squared Error (MSE) loss function.
"""
__all__ = ["MSELoss", "mse_loss"]
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
from typing import Optional
import torch
# =============================================================================
# LOSS FUNCTIONS
# =============================================================================
[docs]
class MSELoss(torch.nn.Module):
    """(Weighted) mean squared error, exposed as a ``torch.nn.Module``.

    The module holds no parameters or state of its own; evaluating it simply
    delegates to the functional form :func:`mse_loss`.
    """

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Evaluate the (weighted) mean squared error.

        Parameters
        ----------
        input : torch.Tensor
            prediction
        target : torch.Tensor
            reference
        weights : torch.Tensor, optional
            sample weights, by default None

        Returns
        -------
        loss : torch.Tensor
            value returned by :func:`mse_loss` for the given arguments
        """
        loss = mse_loss(input=input, target=target, weights=weights)
        return loss
def mse_loss(
    input: torch.Tensor, target: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Compute the (weighted) mean squared error between two tensors.

    One-dimensional ``input``, ``target`` and ``weights`` are each given a
    trailing axis (``(n,)`` becomes ``(n, 1)``) before any arithmetic, so a
    flat vector is treated as a single column rather than being broadcast
    against the rows of a 2-D counterpart.

    When ``weights`` is supplied it multiplies the residual *before* the
    square is taken, so every element contributes ``(weights * diff) ** 2``
    to the mean (i.e. the weights enter the loss squared).

    Parameters
    ----------
    input : torch.Tensor
        prediction
    target : torch.Tensor
        reference
    weights : torch.Tensor, optional
        sample weights, by default None

    Returns
    -------
    loss : torch.Tensor
        mean over all elements of the squared (optionally weighted) residual
    """
    # Promote flat vectors to column layout so the subtraction lines up
    # element by element instead of broadcasting into an outer difference.
    prediction = input.unsqueeze(1) if input.ndim == 1 else input
    reference = target.unsqueeze(1) if target.ndim == 1 else target

    residual = prediction - reference

    # Unweighted case: plain mean of the squared residual.
    if weights is None:
        return residual.square().mean()

    # Per-sample weights get the same column layout so they broadcast
    # across the trailing axis of the residual.
    sample_weights = weights.unsqueeze(1) if weights.ndim == 1 else weights
    return (residual * sample_weights).square().mean()