mlcolvar.cvs.DeepTICA
- class mlcolvar.cvs.DeepTICA(model: List[int] | FeedForward | BaseGNN, n_cvs: int = None, options: dict = None, **kwargs)
Bases: BaseCV
Neural network-based time-lagged independent component analysis (Deep-TICA).
It is a non-linear generalization of TICA in which a feature map is learned by a neural network optimized so as to maximize the eigenvalues of the transfer operator, as approximated by TICA. The method is described in [1]. Note that, from the point of view of the architecture, DeepTICA is similar to the SRV [2] method.
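For reference, the linear TICA problem that is solved on top of the network output can be written as below. This is a sketch of the standard TICA formulation, assuming mean-free features $\mathbf{f}_\theta(\mathbf{x})$ produced by a network with parameters $\theta$; the symbols $\mathbf{f}_\theta$, $\mathbf{v}_i$, $\lambda_i$, $s_i$, $\tau$ and $\mathcal{F}$ are introduced here for illustration, and the exact correlation estimators used by mlcolvar (weighting, symmetrization) may differ.

```latex
\mathbf{C}(0) = \langle \mathbf{f}_\theta(\mathbf{x}_t)\, \mathbf{f}_\theta(\mathbf{x}_t)^{\top} \rangle ,
\qquad
\mathbf{C}(\tau) = \langle \mathbf{f}_\theta(\mathbf{x}_t)\, \mathbf{f}_\theta(\mathbf{x}_{t+\tau})^{\top} \rangle

\mathbf{C}(\tau)\, \mathbf{v}_i = \lambda_i\, \mathbf{C}(0)\, \mathbf{v}_i ,
\qquad
s_i(\mathbf{x}) = \mathbf{v}_i^{\top}\, \mathbf{f}_\theta(\mathbf{x})

\theta^{*} = \arg\max_{\theta}\; \mathcal{F}\left(\lambda_1, \dots, \lambda_{n_{\mathrm{cvs}}}\right)
```

Here the $s_i$ are the CVs and $\mathcal{F}$ stands for whichever reduction of the eigenvalues to a scalar the loss applies.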
- Data: for training it requires a DictDataset containing:
  - If using descriptors as input, the keys ‘data’ (input at time t) and ‘data_lag’ (input at time t+lag), as well as the corresponding ‘weights’ and ‘weights_lag’, which are used to weight the time-correlation functions.
  - If using graphs as input, the keys ‘data_list’ and ‘data_list_lag’, each containing the respective ‘weight’.
In both cases, this dataset can be created with the helper function create_timelagged_dataset.
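To make the expected layout concrete, here is a minimal NumPy sketch that pairs each frame at time t with the frame at time t+lag and returns the four descriptor keys listed above. The helper make_timelagged_pairs is hypothetical and is not mlcolvar's create_timelagged_dataset: it assumes an unbiased trajectory saved at a constant stride (uniform weights unless given) and returns a plain dict rather than a DictDataset.

```python
import numpy as np


def make_timelagged_pairs(X, lag, weights=None):
    """Hypothetical helper: pair frame t with frame t + lag.

    Uses the same key names as the DictDataset described above, but it is
    NOT mlcolvar's create_timelagged_dataset.
    """
    X = np.asarray(X, dtype=float)
    n = X.shape[0]
    if not 0 < lag < n:
        raise ValueError("lag must be between 1 and n_frames - 1")
    if weights is None:
        weights = np.ones(n)  # assumption: unbiased trajectory, uniform weights
    weights = np.asarray(weights, dtype=float)
    return {
        "data": X[: n - lag],         # inputs at time t
        "data_lag": X[lag:],          # inputs at time t + lag
        "weights": weights[: n - lag],
        "weights_lag": weights[lag:],
    }


traj = np.arange(10, dtype=float).reshape(5, 2)  # 5 frames, 2 descriptors
pairs = make_timelagged_pairs(traj, lag=2)
print(pairs["data"].shape, pairs["data_lag"].shape)
```

Row i of "data" and row i of "data_lag" form one (t, t+lag) pair; the time-correlation functions are averages over such pairs.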
Loss: maximize TICA eigenvalues (ReduceEigenvaluesLoss)
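The idea of turning several eigenvalues into a single scalar to minimize can be sketched as follows. This is not the ReduceEigenvaluesLoss implementation: the function reduce_eigenvalues and its two modes ("sum" and "sum2") are illustrative assumptions. The sign is flipped so that minimizing the loss maximizes the eigenvalues.

```python
import numpy as np


def reduce_eigenvalues(evals, mode="sum", n_eig=None):
    """Hypothetical reduction of TICA eigenvalues to a scalar loss."""
    evals = np.sort(np.asarray(evals, dtype=float))[::-1]  # largest first
    if n_eig is not None:
        evals = evals[:n_eig]  # keep only the leading eigenvalues
    if mode == "sum":
        return -evals.sum()
    if mode == "sum2":
        return -(evals ** 2).sum()
    raise ValueError(f"unknown mode: {mode!r}")


print(reduce_eigenvalues([0.9, 0.5, 0.1], mode="sum", n_eig=2))
print(reduce_eigenvalues([0.1, 0.9, 0.5], mode="sum2", n_eig=2))
```

Squaring the eigenvalues ("sum2") rewards large eigenvalues more strongly than the plain sum; which reduction suits a given system is a modeling choice.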
References
See also
mlcolvar.core.stats.TICA: Time-lagged independent component analysis
mlcolvar.core.loss.ReduceEigenvaluesLoss: Eigenvalue reduction to a scalar quantity
mlcolvar.utils.timelagged.create_timelagged_dataset: Create a dataset of time-lagged data.
- __init__(model: List[int] | FeedForward | BaseGNN, n_cvs: int = None, options: dict = None, **kwargs)
Define a Deep-TICA CV composed of a neural network module and a TICA object. By default, a module that standardizes the inputs is also used.
- Parameters:
model (list or FeedForward or BaseGNN) – Determines the underlying machine-learning model. One can pass:
  1. A list of integers corresponding to the number of neurons per layer of a feed-forward NN. The model will be automatically initialized using a mlcolvar.core.nn.feedforward.FeedForward object. The CV class will be initialized according to the DEFAULT_BLOCKS.
  2. An externally initialized model (either a mlcolvar.core.nn.feedforward.FeedForward or a mlcolvar.core.nn.graph.BaseGNN object). The CV class will be initialized according to the MODEL_BLOCKS.
n_cvs (int, optional) – Number of CVs to optimize; by default None (= the size of the last layer)
options (dict[str, Any], optional) – Options for the building blocks of the model, by default {}. Available blocks: [‘norm_in’, ‘nn’, ‘tica’]. Set the entry of a block to None or False (e.g. options[‘norm_in’] = None) to turn off that block.
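The sketch below illustrates, under stated assumptions, how the options dictionary and n_cvs described above are interpreted: a block whose entry is None or False is skipped, a missing entry defaults to {}, and n_cvs=None falls back to the size of the last layer. The helpers active_blocks and default_n_cvs and the tuple BLOCK_NAMES are hypothetical; they restate the documented rules and are not part of mlcolvar.

```python
BLOCK_NAMES = ("norm_in", "nn", "tica")  # hypothetical: block names listed above


def active_blocks(options=None, block_names=BLOCK_NAMES):
    """Hypothetical: names of the blocks that would be built for these options."""
    options = {} if options is None else dict(options)
    active = []
    for name in block_names:
        block_opts = options.get(name, {})  # missing entry -> default {}
        if block_opts is None or block_opts is False:
            continue  # block explicitly turned off
        active.append(name)
    return active


def default_n_cvs(layers, n_cvs=None):
    """Hypothetical: n_cvs=None falls back to the size of the last layer."""
    return layers[-1] if n_cvs is None else n_cvs


print(active_blocks())                   # ['norm_in', 'nn', 'tica']
print(active_blocks({"norm_in": None}))  # ['nn', 'tica']
print(default_n_cvs([10, 32, 32, 2]))    # 2
```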
Methods
__init__(model[, n_cvs, options]): Define a Deep-TICA CV, composed of a neural network module and a TICA object.
forward_nn(x)
set_regularization([c0_reg]): Add identity matrix multiplied by c0_reg to correlation matrix C(0) to avoid instabilities in performing Cholesky and .
training_step(train_batch, batch_idx): Compute and return the training loss and record metrics.
- configure_model() None
Hook to create modules in a strategy and precision aware context.
This is particularly useful when using sharded strategies (FSDP and DeepSpeed), where we’d like to shard the model instantly to save memory and initialization time. For non-sharded strategies, you can choose to override this hook or to initialize your model under the init_module() context manager.
This hook is called during each of the fit/val/test/predict stages in the same process, so ensure that the implementation of this hook is idempotent, i.e., after the first time the hook is called, subsequent calls to it should be a no-op.
- configure_optimizers()
Initialize the optimizer based on self._optimizer_name and self.optimizer_kwargs. It also adds the learning rate scheduler if self.lr_scheduler_kwargs is not empty. The scheduler is given as a dictionary with the key ‘scheduler’ containing the scheduler class and the rest of the keys are config options for the scheduler.
- Returns:
torch.optim – Torch optimizer
dict, optional – Learning rate scheduler configuration (if any)
- configure_sharded_model() None
Deprecated. Use configure_model() instead.
- cpu() Self
See torch.nn.Module.cpu().
- cuda(device: device | int | None = None) Self
Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on the GPU while being optimized.
- Arguments:
- device: If specified, all parameters will be copied to that device. If None, the current CUDA device index will be used.
- Returns:
Module: self
- double() Self
See torch.nn.Module.double().
- property example_input_array
The example input array is a specification of what the module can consume in the forward() method. The return type is interpreted as follows:
  - Single tensor: It is assumed the model takes a single argument, i.e., model.forward(model.example_input_array)
  - Tuple: The input array should be interpreted as a sequence of positional arguments, i.e., model.forward(*model.example_input_array)
  - Dict: The input array represents named keyword arguments, i.e., model.forward(**model.example_input_array)
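The three documented forms of example_input_array map onto a call of forward() as in this plain-Python sketch. The function call_with_example and the toy_forward stand-in are hypothetical; they only restate the rules above and are not Lightning's internal code.

```python
def call_with_example(forward, example):
    """Hypothetical: dispatch an example input according to its documented form."""
    if isinstance(example, tuple):
        return forward(*example)   # Tuple: positional arguments
    if isinstance(example, dict):
        return forward(**example)  # Dict: named keyword arguments
    return forward(example)        # Single tensor: one argument


def toy_forward(x, y=0):
    """Stand-in for a model's forward()."""
    return x + y


print(call_with_example(toy_forward, 3))                 # 3
print(call_with_example(toy_forward, (1, 2)))            # 3
print(call_with_example(toy_forward, {"x": 1, "y": 5}))  # 6
```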
- float() Self
See torch.nn.Module.float().
- forward(x: Tensor, cell=None) Tensor
Evaluation of the CV:
  1. Apply preprocessing, if any.
  2. Execute sequentially all the blocks in self.BLOCKS, skipping those that are not initialized.
  3. Apply postprocessing, if any.
- Parameters:
x (torch.Tensor) – Input of the forward operation of the model
- Returns:
Output of the forward operation of the model
- Return type:
torch.Tensor
- forward_cv(x: Tensor) Tensor
Execute sequentially all the blocks in self.BLOCKS, skipping those that are not initialized.
No pre/post-processing is executed here. This method is meant to be called during training/validation and can be overridden if necessary.
- Parameters:
x (torch.Tensor) – Input of the forward operation of the model
- Returns:
Output of the forward operation of the model
- Return type:
torch.Tensor
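The order of operations in forward() and forward_cv() can be sketched with a toy class in which blocks are plain callables and None marks a block that was not initialized. ToyCV and its attributes are hypothetical stand-ins, not the BaseCV implementation; the cell argument is not modeled, and a real DeepTICA model operates on torch.Tensor inputs.

```python
class ToyCV:
    """Hypothetical stand-in illustrating the documented forward() order."""

    def __init__(self, blocks, preprocessing=None, postprocessing=None):
        self.blocks = blocks                  # callables, or None if not initialized
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing

    def forward_cv(self, x):
        # Blocks only: no pre/post-processing here (used during training)
        for block in self.blocks:
            if block is not None:             # skip blocks that are not initialized
                x = block(x)
        return x

    def forward(self, x):
        if self.preprocessing is not None:    # 1. preprocessing, if any
            x = self.preprocessing(x)
        x = self.forward_cv(x)                # 2. all initialized blocks
        if self.postprocessing is not None:   # 3. postprocessing, if any
            x = self.postprocessing(x)
        return x


cv = ToyCV(
    blocks=[lambda x: 2 * x, None, lambda x: x + 1],
    preprocessing=lambda x: x - 1,
    postprocessing=lambda x: -x,
)
print(cv.forward(3))     # -5
print(cv.forward_cv(3))  # 7
```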
- half() Self
See torch.nn.Module.half().
- property hparams: AttributeDict | MutableMapping
The collection of hyperparameters saved with save_hyperparameters(). It is mutable by the user. For the frozen set of initial hyperparameters, use hparams_initial.
- Returns:
Mutable hyperparameters dictionary
- property hparams_initial: AttributeDict
The collection of hyperparameters saved with save_hyperparameters(). These contents are read-only. Manual updates to the saved hyperparameters can instead be performed through hparams.
- Returns:
AttributeDict: immutable initial hyperparameters
- initialize_blocks()
Initialize the blocks as attributes of the CV class.
- property n_cvs
Number of CVs.
- on_after_backward() None
Called after loss.backward() and before optimizers are stepped.
- Note:
If using native AMP, the gradients will not be unscaled at this point. Use on_before_optimizer_step if you need the unscaled gradients.
- on_after_batch_transfer(batch: Any, dataloader_idx: int) Any
Override to alter or apply batch augmentations to your batch after it is transferred to the device.
- Note:
To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.
- Args:
batch: A batch of data that needs to be altered or augmented.
dataloader_idx: The index of the dataloader to which the batch belongs.
- Returns:
A batch of data
Example:
```python
def on_after_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = gpu_transforms(batch['x'])
    return batch
```
- on_before_backward(loss: Tensor) None
Called before loss.backward().
- Args:
loss: Loss divided by number of batches for gradient accumulation and scaled if using AMP.
- on_before_batch_transfer(batch: Any, dataloader_idx: int) Any
Override to alter or apply batch augmentations to your batch before it is transferred to the device.
- Note:
To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.
- Args:
batch: A batch of data that needs to be altered or augmented.
dataloader_idx: The index of the dataloader to which the batch belongs.
- Returns:
A batch of data
Example:
```python
def on_before_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = transforms(batch['x'])
    return batch
```
- on_before_optimizer_step(optimizer: Optimizer) None
Called before optimizer.step().
If using gradient accumulation, the hook is called once the gradients have been accumulated. See: Trainer.accumulate_grad_batches.
If using AMP, the loss will be unscaled before calling this hook. See these docs for more information on the scaling of gradients.
If clipping gradients, the gradients will not have been clipped yet.
- Args:
optimizer: Current optimizer being used.
Example:
```python
def on_before_optimizer_step(self, optimizer):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for k, v in self.named_parameters():
            self.logger.experiment.add_histogram(
                tag=k, values=v.grad, global_step=self.trainer.global_step
            )
```
- on_before_zero_grad(optimizer: Optimizer) None
Called after training_step() and before optimizer.zero_grad().
Called in the training loop after taking an optimizer step and before zeroing grads. Good place to inspect weight information with weights updated.
This is where it is called:
```python
for optimizer in optimizers:
    out = training_step(...)

    model.on_before_zero_grad(optimizer)  # < ---- called here
    optimizer.zero_grad()

    backward()
```
- Args:
optimizer: The optimizer for which grads should be zeroed.
- on_fit_end() None
Called at the very end of fit.
If on DDP, it is called on every process.
- on_fit_start()
Called at the very beginning of fit.
If on DDP, it is called on every process.
- on_load_checkpoint(checkpoint: dict[str, Any]) None
Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.
- Args:
checkpoint: Loaded checkpoint
Example:
```python
def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
```
- Note:
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.
- on_predict_batch_end(outputs: Any | None, batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Called in the predict loop after the batch.
- Args:
outputs: The outputs of predict_step(x)
batch: The batched data as it is returned by the prediction DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
- on_predict_batch_start(batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Called in the predict loop before anything happens for that batch.
- Args:
batch: The batched data as it is returned by the prediction DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
- on_predict_end() None
Called at the end of predicting.
- on_predict_epoch_end() None
Called at the end of predicting.
- on_predict_epoch_start() None
Called at the beginning of predicting.
- on_predict_model_eval() None
Called when the predict loop starts.
The predict loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior.
- on_predict_start() None
Called at the beginning of predicting.
- on_save_checkpoint(checkpoint: dict[str, Any]) None
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Args:
- checkpoint: The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
Example:
```python
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
```
- Note:
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- on_test_batch_end(outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Called in the test loop after the batch.
- Args:
outputs: The outputs of test_step(x)
batch: The batched data as it is returned by the test DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
- on_test_batch_start(batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Called in the test loop before anything happens for that batch.
- Args:
batch: The batched data as it is returned by the test DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
- on_test_end() None
Called at the end of testing.
- on_test_epoch_end() None
Called in the test loop at the very end of the epoch.
- on_test_epoch_start() None
Called in the test loop at the very beginning of the epoch.
- on_test_model_eval() None
Called when the test loop starts.
The test loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_test_model_train().
- on_test_model_train() None
Called when the test loop ends.
The test loop by default restores the training mode of the LightningModule to what it was before starting testing. Override this hook to change the behavior. See also on_test_model_eval().
- on_test_start() None
Called at the beginning of testing.
- on_train_batch_end(outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int) None
Called in the training loop after the batch.
- Args:
outputs: The outputs of training_step(x)
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
- Note:
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
- on_train_batch_start(batch: Any, batch_idx: int) int | None
Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch. The learning rate scheduler will still be stepped at the end of the epoch.
- Args:
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
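The effect of returning -1 from on_train_batch_start can be pictured with a toy training loop. run_epoch, skip_negative and sum_step are hypothetical simplifications written for illustration; the real Lightning loop does much more (for example, it still steps the learning rate scheduler at the end of the epoch, which is not modeled here).

```python
def run_epoch(batches, on_train_batch_start, training_step):
    """Hypothetical toy loop: -1 from the hook skips the rest of the epoch."""
    losses = []
    for batch_idx, batch in enumerate(batches):
        if on_train_batch_start(batch, batch_idx) == -1:
            break  # skip training for the remaining batches of this epoch
        losses.append(training_step(batch, batch_idx))
    return losses


def skip_negative(batch, batch_idx):
    # Ask the loop to stop as soon as a batch contains a negative value
    return -1 if min(batch) < 0 else None


def sum_step(batch, batch_idx):
    # Stand-in "loss": the sum of the batch
    return sum(batch)


print(run_epoch([[1, 2], [3, 4], [-1, 5], [6, 7]], skip_negative, sum_step))  # [3, 7]
```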
- on_train_end() None
Called at the end of training before logger experiment is closed.
- on_train_epoch_end() None
Called in the training loop at the very end of the epoch.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:
```python
import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()
```
- on_train_epoch_start() None
Called in the training loop at the very beginning of the epoch.
- on_train_start() None
Called at the beginning of training after sanity check.
- on_validation_batch_end(outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Called in the validation loop after the batch.
- Args:
outputs: The outputs of validation_step(x)
batch: The batched data as it is returned by the validation DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
- on_validation_batch_start(batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Called in the validation loop before anything happens for that batch.
- Args:
batch: The batched data as it is returned by the validation DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
- on_validation_end() None
Called at the end of validation.
- on_validation_epoch_end() None
Called in the validation loop at the very end of the epoch.
- on_validation_epoch_start() None
Called in the validation loop at the very beginning of the epoch.
- on_validation_model_eval() None
Called when the validation loop starts.
The validation loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_validation_model_train().
- on_validation_model_train() None
Called when the validation loop ends.
The validation loop by default restores the training mode of the LightningModule to what it was before starting validation. Override this hook to change the behavior. See also on_validation_model_eval().
- on_validation_model_zero_grad() None
Called by the training loop to release gradients before entering the validation loop.
- on_validation_start() None
Called at the beginning of validation.
- property optimizer_name: str
Optimizer name. Options can be set using optimizer_kwargs. The actual optimizer is returned during training by the configure_optimizers() method.
- parse_options(options: dict = None)
Sanitize options and create defaults ({}) for blocks missing from options. It also sets the optimizer kwargs, if given.
- Parameters:
options (dict[str, Any], optional) – Options for the building blocks of the model, by default None.
- predict_dataloader() Any
An iterable or collection of iterables specifying prediction samples.
For more information about multiple dataloaders, see this section.
It’s recommended that all data downloads and preparation happen in prepare_data().
- predict()
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return:
A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
- prepare_data() None
Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.
Warning
DO NOT set state to the model (use setup instead) since this is NOT called on every device.
Example:
```python
def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()

    # bad
    self.split = data_split
    self.some_state = some_other_state()
```
In a distributed environment, prepare_data can be called in two ways (using prepare_data_per_node):
  - Once per node. This is the default and is only called on LOCAL_RANK=0.
  - Once in total. Only called on GLOBAL_RANK=0.
Example:
```python
# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True


# call on GLOBAL_RANK=0 (great for shared file systems)
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = False
```
This is called before requesting the dataloaders:
```python
model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()
```
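Which processes end up calling prepare_data() under the two settings above can be summarized with a small predicate. calls_prepare_data is a hypothetical helper that restates the documented rule; it assumes the usual numbering in which local ranks restart from 0 on each node while global ranks are unique across nodes.

```python
def calls_prepare_data(local_rank, global_rank, prepare_data_per_node=True):
    """Hypothetical restatement of the rule described above."""
    if prepare_data_per_node:
        return local_rank == 0   # once per node (the default)
    return global_rank == 0      # once in total


# 2 nodes x 2 devices: (local_rank, global_rank) of every process
ranks = [(0, 0), (1, 1), (0, 2), (1, 3)]
print(sum(calls_prepare_data(lr, gr) for lr, gr in ranks))         # 2
print(sum(calls_prepare_data(lr, gr, False) for lr, gr in ranks))  # 1
```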
- remove_ignored_hparams(ignore_list: list[str]) None
Remove ignored hyperparameters from the stored state.
This allows derived classes to drop hyperparameters previously saved by base classes.
- Args:
ignore_list: Names of hyperparameters to remove.
- save_hyperparameters(*args: Any, ignore: Sequence[str] | str | None = None, frame: FrameType | None = None, logger: bool = True) None
Save arguments to the hparams attribute.
- Args:
  - args: single object of dict, NameSpace or OmegaConf or string names or arguments from class __init__
  - ignore: an argument name or a list of argument names from class __init__ to be ignored
  - frame: a frame object. Default is None
  - logger: Whether to send the hyperparameters to the logger. Default: True
- Example:
```python
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # manually assign arguments
...         self.save_hyperparameters('arg1', 'arg3')
...     def forward(self, *args, **kwargs):
...         ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14

>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class AutomaticArgsModel(HyperparametersMixin):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # equivalent automatic
...         self.save_hyperparameters()
...     def forward(self, *args, **kwargs):
...         ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg2": abc
"arg3": 3.14

>>> from argparse import Namespace
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class SingleArgModel(HyperparametersMixin):
...     def __init__(self, params):
...         super().__init__()
...         # manually assign single argument
...         self.save_hyperparameters(params)
...     def forward(self, *args, **kwargs):
...         ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.hparams
"p1": 1
"p2": abc
"p3": 3.14

>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # pass argument(s) to ignore as a string or in a list
...         self.save_hyperparameters(ignore='arg2')
...     def forward(self, *args, **kwargs):
...         ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
```
- set_regularization(c0_reg=1e-06)
Add identity matrix multiplied by c0_reg to correlation matrix C(0) to avoid instabilities in performing Cholesky and .
- Parameters:
c0_reg (float) – Regularization value for C(0).
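Why the regularization matters can be seen in a small NumPy example: if one feature is constant (so that after mean removal its column is all zeros), C(0) is singular and the Cholesky factorization fails, while adding c0_reg times the identity makes it positive definite. regularize_c0 and cholesky_ok are hypothetical helpers written for this sketch, and C(0) is estimated here with a plain, unweighted average.

```python
import numpy as np


def regularize_c0(C0, c0_reg=1e-6):
    """Return C(0) + c0_reg * I, the shift described for set_regularization."""
    C0 = np.asarray(C0, dtype=float)
    return C0 + c0_reg * np.eye(C0.shape[0])


def cholesky_ok(M):
    """True if NumPy can Cholesky-factorize M (i.e. M is positive definite)."""
    try:
        np.linalg.cholesky(M)
        return True
    except np.linalg.LinAlgError:
        return False


# Mean-free features; the second one is constant, hence all zeros after centering
f = np.array([[1.0, 0.0], [-1.0, 0.0], [2.0, 0.0], [-2.0, 0.0]])
C0 = f.T @ f / len(f)  # [[2.5, 0.0], [0.0, 0.0]]: singular

print(cholesky_ok(C0))                 # False
print(cholesky_ok(regularize_c0(C0)))  # True
```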
- setup(stage=None)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Args:
stage: either 'fit', 'validate', 'test', or 'predict'
Example:
```python
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = something_else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
```
- teardown(stage: str) None
Called at the end of fit (train + validate), validate, test, or predict.
- Args:
stage: either 'fit', 'validate', 'test', or 'predict'
- test_dataloader() Any
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
  - download in prepare_data()
  - process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data.
- test()
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Note:
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
- test_step(test_batch, batch_idx)
Equal to training step if not overridden. Different behaviors for train/valid step can be enforced in training_step() based on the self.training variable.
- to(*args: Any, **kwargs: Any) Self
See torch.nn.Module.to().
- to_torchscript(file_path: str | Path | None = None, method: str | None = 'script', example_inputs: Any | None = None, **kwargs: Any) ScriptModule | Dict[str, ScriptModule]
By default, compiles the whole model to a torch.jit.ScriptModule. Tracing can be used with the argument method=’trace’. In that case, you can provide example_inputs; otherwise, the default example_input_array will be used.
- Args:
file_path: Path where to save the torchscript. Default: None (no file saved).
method: Whether to use TorchScript’s script or trace method. Default: ‘script’
example_inputs: An input to be used to do tracing when method is set to ‘trace’. Default: None (uses example_input_array)
**kwargs: Additional arguments that will be passed to the torch.jit.script() or torch.jit.trace() function.
- Return:
This LightningModule as a torchscript, regardless of whether file_path is defined or not.
- train_dataloader() Any
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
  - download in prepare_data()
  - process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data.
- fit()
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- training_step(train_batch, batch_idx)
Compute and return the training loss and record metrics:
  1. Calculate the NN output.
  2. Remove the average (inside forward_nn).
  3. Compute TICA.
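The three steps above can be sketched end to end in NumPy. This is an illustrative reimplementation under several assumptions, not mlcolvar's code: the "network" is a fixed tanh map, a single weight per (t, t+lag) pair is taken from the batch's "weights" entry, the correlation matrices are symmetrized weighted averages, and the loss is simply the negative sum of the eigenvalues. The helpers remove_average, tica_eigenvalues and training_step_sketch are hypothetical. The generalized problem C(τ)v = λC(0)v is reduced to a symmetric one through the Cholesky factor of the regularized C(0).

```python
import numpy as np


def remove_average(x_t, x_lag, w):
    """Subtract the weighted mean estimated over both time-lagged sets."""
    x_all = np.concatenate([x_t, x_lag])
    w_all = np.concatenate([w, w])
    mean = (w_all[:, None] * x_all).sum(axis=0) / w_all.sum()
    return x_t - mean, x_lag - mean


def tica_eigenvalues(x_t, x_lag, w, c0_reg=1e-6):
    """Eigenvalues of C(tau) v = lam C(0) v, largest first."""
    wn = w / w.sum()
    c_tt = (wn[:, None] * x_t).T @ x_t
    c_ll = (wn[:, None] * x_lag).T @ x_lag
    c_tl = (wn[:, None] * x_t).T @ x_lag
    C0 = 0.5 * (c_tt + c_ll) + c0_reg * np.eye(x_t.shape[1])  # regularized C(0)
    Ct = 0.5 * (c_tl + c_tl.T)                                 # symmetrized C(tau)
    # Cholesky reduction: with C0 = L L^T, solve (L^-1 Ct L^-T) u = lam u
    L_inv = np.linalg.inv(np.linalg.cholesky(C0))
    return np.sort(np.linalg.eigvalsh(L_inv @ Ct @ L_inv.T))[::-1]


def training_step_sketch(nn, batch, c0_reg=1e-6):
    # 1) NN output at time t and at time t + lag
    x_t, x_lag = nn(batch["data"]), nn(batch["data_lag"])
    w = batch["weights"]  # assumption: one weight per (t, t + lag) pair
    # 2) remove the (weighted) average
    x_t, x_lag = remove_average(x_t, x_lag, w)
    # 3) TICA, then reduce the eigenvalues to a loss to be minimized
    evals = tica_eigenvalues(x_t, x_lag, w, c0_reg)
    return -evals.sum(), evals


# Toy trajectory: one slow AR(1) coordinate plus one white-noise coordinate
rng = np.random.default_rng(0)
n, lag = 2000, 1
slow = np.zeros(n)
for t in range(1, n):
    slow[t] = 0.99 * slow[t - 1] + 0.1 * rng.normal()
traj = np.column_stack([slow, rng.normal(size=n)])
batch = {"data": traj[:-lag], "data_lag": traj[lag:], "weights": np.ones(n - lag)}

W = np.array([[1.0, 0.3], [0.2, 1.0]])  # fixed, untrained "network" weights
loss, evals = training_step_sketch(lambda x: np.tanh(x @ W), batch)
print(evals, loss)
```

With these symmetrized estimators the eigenvalues lie in [-1, 1]; in actual training, gradients of the loss would update the network weights so as to increase them.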
- transfer_batch_to_device(batch: Any, device: device, dataloader_idx: int) Any
Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
  - torch.Tensor or anything that implements .to(…)
  - list
  - dict
  - tuple
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).
- Note:
This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing). To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.
- Args:
batch: A batch of data that needs to be transferred to a new device.
device: The target device as defined in PyTorch.
dataloader_idx: The index of the dataloader to which the batch belongs.
- Returns:
A reference to the data on the new device.
Example:
```python
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch
```
- See Also:
move_data_to_device(), apply_to_collection()
- type(dst_type: str | dtype) Self
See torch.nn.Module.type().
- val_dataloader() Any
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
- fit()
- validate()
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Note:
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
- validation_step(val_batch, batch_idx)
Equal to training step if not overridden. Different behaviors for train/valid step can be enforced in training_step() based on the self.training variable.
Attributes
CHECKPOINT_HYPER_PARAMS_KEY
CHECKPOINT_HYPER_PARAMS_NAME
CHECKPOINT_HYPER_PARAMS_TYPE
DEFAULT_BLOCKS
MODEL_BLOCKS
T_destination
automatic_optimization: If set to False, you are responsible for calling .backward(), .step(), .zero_grad().
call_super_init
current_epoch: The current epoch in the Trainer, or 0 if not attached.
device
device_mesh: Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.
dtype
dump_patches
example_input_array: The example input array is a specification of what the module can consume in the forward() method.
fabric
global_rank: The index of the current process across all nodes and devices.
global_step: Total training batches seen across all epochs.
hparams: The collection of hyperparameters saved with save_hyperparameters().
hparams_initial: The collection of hyperparameters saved with save_hyperparameters().
local_rank: The index of the current process within a single node.
logger: Reference to the logger object in the Trainer.
loggers: Reference to the list of loggers in the Trainer.
n_cvs: Number of CVs.
on_gpu: Returns True if this model is currently located on a GPU.
optimizer_name: Optimizer name.
strict_loading: Determines how Lightning loads this model using .load_state_dict(..., strict=model.strict_loading).
trainer
training