"""Linear discriminant analysis"""
import torch
from .utils import cholesky_eigh
from mlcolvar.core.stats import Stats
# Public API of this module.
__all__ = ["LDA"]
class LDA(Stats):
    """
    Fisher's discriminant class.

    Attributes
    ----------
    in_features : int
        Number of input features
    out_features : int
        Number of output features (n_states - 1)
    n_states : int
        Number of states
    evals : torch.Tensor
        LDA eigenvalues
    evecs : torch.Tensor
        LDA eigenvectors
    S_b : torch.Tensor
        Between scatter matrix
    S_w : torch.Tensor
        Within scatter matrix
    sw_reg : float
        Regularization to S_w matrix
    """

    def __init__(self, in_features: int, n_states: int, mode: str = "standard"):
        """
        Initialize a LDA object.

        Parameters
        ----------
        in_features : int
            number of input features
        n_states : int
            number of states
        mode : string
            options: {'standard','harmonic'}, standard or harmonic variant

        Raises
        ------
        ValueError
            If ``mode`` is neither 'standard' nor 'harmonic'.
        """
        super().__init__()

        # Save attributes
        self.n_states = n_states
        self.in_features = in_features
        self.out_features = n_states - 1

        if mode not in ("standard", "harmonic"):
            raise ValueError(f"LDA mode should be standard or harmonic, not {mode}.")
        self.mode = mode

        # create eigenvector buffer
        self.register_buffer("evecs", torch.eye(in_features, self.out_features))

        # initialize other attributes (filled in by compute / compute_scatter_matrices)
        self.evals = None
        self.S_b = None
        self.S_w = None

        # Regularization passed to cholesky_eigh together with S_w
        self.sw_reg = 1e-6

    def compute(self, X: torch.Tensor, labels: torch.Tensor, save_params: bool = True):
        """
        Compute LDA eigenvalues and eigenvectors.

        First compute the scatter matrices S_between (S_b) and S_within (S_w)
        and then solve the generalized eigenvalue problem.

        Parameters
        ----------
        X : array-like of shape (n_samples, in_features)
            Training data.
        labels : array-like of shape (n_samples,)
            states labels.
        save_params: bool, optional
            Whether to store parameters in model

        Returns
        -------
        eigvals : torch.Tensor
            LDA eigenvalues (n_states-1)
        eigvecs : torch.Tensor
            LDA eigenvectors (n_feature,n_states-1)

        Notes
        -----
        The eigenvecs object which is returned is a matrix whose column
        eigvecs[:,i] is the eigenvector associated to eigvals[i]
        """
        S_b, S_w = self.compute_scatter_matrices(X, labels, save_params)
        evals, evecs = cholesky_eigh(S_b, S_w, self.sw_reg, n_eig=self.n_states - 1)

        if save_params:
            self.evals = evals
            self.evecs = evecs

        return evals, evecs

    def compute_scatter_matrices(
        self, X: torch.Tensor, labels: torch.Tensor, save_params: bool = True
    ):
        """
        Compute between scatter and within scatter matrices.

        Parameters
        ----------
        X : array-like of shape (n_samples, in_features)
            Training data.
        labels : array-like of shape (n_samples,)
            states labels.
        save_params: bool, optional
            Whether to store parameters in model

        Returns
        -------
        S_b,S_w : arrays of shape (in_features,in_features)
            Between and within scatter matrices

        Notes
        -----
        ``self.in_features`` is overwritten with ``X.shape[1]`` on every call,
        regardless of ``save_params``. The number of states used for the
        normalization is the number of unique values found in ``labels``.
        Each state's scatter is divided by ``(N_i - 1) * n_states``, so a state
        with a single sample produces a zero divisor.
        """
        # sizes
        N, d = X.shape
        self.in_features = d

        # states
        states = torch.unique(labels)
        n_states = len(states)

        # Total scatter matrix (cov matrix over all mean-centered observations)
        X_bar = X - torch.mean(X, 0, True)
        S_t = X_bar.t().matmul(X_bar) / (N - 1)

        harmonic = self.mode == "harmonic"

        # Accumulator allocated from X so it shares X's dtype and device;
        # a default-dtype buffer would silently downcast float64 inputs and
        # leave S_w and S_b with different dtypes.
        accumulated = X.new_zeros((d, d))

        # Loop over states to compute means and covs
        for state in states:
            # select the observations belonging to this state
            X_i = X[torch.nonzero(labels == state).view(-1)]
            N_i = X_i.shape[0]
            if N_i == 0:
                continue

            # mean-centered observations and normalized scatter of this state
            X_i_bar = X_i - torch.mean(X_i, 0, True)
            scatter_i = X_i_bar.t().matmul(X_i_bar) / ((N_i - 1) * n_states)

            # LDA sums the per-state scatters; HLDA sums their inverses
            accumulated += scatter_i.inverse() if harmonic else scatter_i

        # HLDA: within scatter is the inverse of the summed inverses
        S_w = accumulated.inverse() if harmonic else accumulated

        # Compute S_b from total scatter matrix
        S_b = S_t - S_w

        if save_params:
            self.S_b = S_b
            self.S_w = S_w

        return S_b, S_w

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute linear combination with saved eigenvectors

        Parameters
        ----------
        x: torch.Tensor
            input

        Returns
        -------
        out : torch.Tensor
            output
        """
        return torch.matmul(x, self.evecs)
def test_lda():
    """Smoke-test LDA in both 'standard' and 'harmonic' modes on random data.

    Checks that projecting the data with the fitted eigenvectors gives a
    2-D tensor with ``n_states - 1`` columns for each variant.
    """
    in_features = 2
    n_states = 2

    torch.manual_seed(42)
    X = torch.rand(100, in_features) * 100
    y = torch.randint(n_states, (100, 1)).squeeze(1)

    # standard
    lda = LDA(in_features, n_states)
    print(lda)
    S_b, S_w = lda.compute_scatter_matrices(X, y)
    print(S_w, S_b)
    evals, evecs = lda.compute(X, y, True)
    print(lda.S_w.shape, lda.S_b.shape)
    print(evals.shape, evecs.shape)
    s = lda(X)
    print(s.shape)
    assert (s.ndim == 2) and (s.shape[1] == n_states - 1)

    # harmonic variant
    hlda = LDA(in_features, n_states, mode="harmonic")
    print(hlda)
    hlda.compute(X, y)
    # project with the harmonic model itself (not the standard one above)
    s = hlda(X)
    assert (s.ndim == 2) and (s.shape[1] == n_states - 1)