trainer
Trainer to automate the training.
Classes
Trainer | Customize every aspect of training via flags
-
class pytorch_lightning.trainer.trainer.Trainer(logger=True, checkpoint_callback=True, callbacks=None, default_root_dir=None, gradient_clip_val=0, process_position=0, num_nodes=1, num_processes=1, gpus=None, auto_select_gpus=False, tpu_cores=None, log_gpu_memory=None, progress_bar_refresh_rate=1, overfit_batches=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_epochs=1000, min_epochs=1, max_steps=None, min_steps=None, limit_train_batches=1.0, limit_val_batches=1.0, limit_test_batches=1.0, val_check_interval=1.0, flush_logs_every_n_steps=100, log_every_n_steps=50, accelerator=None, sync_batchnorm=False, precision=32, weights_summary='top', weights_save_path=None, num_sanity_val_steps=2, truncated_bptt_steps=None, resume_from_checkpoint=None, profiler=None, benchmark=False, deterministic=False, reload_dataloaders_every_epoch=False, auto_lr_find=False, replace_sampler_ddp=True, terminate_on_nan=False, auto_scale_batch_size=False, prepare_data_per_node=True, plugins=None, amp_backend='native', amp_level='O2', distributed_backend=None, automatic_optimization=None, move_metrics_to_cpu=False, enable_pl_optimizer=None)
Bases: pytorch_lightning.trainer.properties.TrainerProperties, pytorch_lightning.trainer.callback_hook.TrainerCallbackHookMixin, pytorch_lightning.trainer.model_hooks.TrainerModelHooksMixin, pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin, pytorch_lightning.trainer.logging.TrainerLoggingMixin, pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin, pytorch_lightning.trainer.data_loading.TrainerDataLoadingMixin, pytorch_lightning.trainer.deprecated_api.DeprecatedDistDeviceAttributes
Customize every aspect of training via flags.
- Parameters
accelerator (Union[str, Accelerator, None]) – Previously known as distributed_backend (dp, ddp, ddp2, etc.). Can also take an accelerator object for custom hardware.
accumulate_grad_batches (Union[int, Dict[int, int], List[list]]) – Accumulates gradients every k batches, or as set up in the dict.
amp_backend (str) – The mixed precision backend to use (“native” or “apex”).
amp_level (str) – The optimization level to use (O1, O2, etc.).
auto_lr_find (Union[bool, str]) – If set to True, trainer.tune() runs a learning rate finder that tries to optimize the initial learning rate for faster convergence. trainer.tune() sets the suggested learning rate in self.lr or self.learning_rate in the LightningModule. To use a different key, pass the key name as a string instead of True.
auto_scale_batch_size (Union[str, bool]) – If set to True, initially runs a batch size finder that tries to find the largest batch size that fits into memory. The result is stored in self.batch_size in the LightningModule. Can also be set to either power, which estimates the batch size through a power search, or binsearch, which estimates the batch size through a binary search.
auto_select_gpus (bool) – If enabled and gpus is an integer, pick available GPUs automatically. This is especially useful when GPUs are configured to be in “exclusive mode”, such that only one process at a time can access them.
callbacks (Optional[List[Callback]]) – Add a list of callbacks.
checkpoint_callback (bool) – If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks. Default: True.
Warning: Passing a ModelCheckpoint instance to this argument is deprecated since v1.1 and will be unsupported from v1.3. Use the callbacks argument instead.
check_val_every_n_epoch (int) – Run validation every n training epochs.
default_root_dir (Optional[str]) – Default path for logs and weights when no logger or checkpoint callback is passed. Default: os.getcwd(). Can be a remote file path such as s3://mybucket/path or ‘hdfs://path/’.
deterministic (bool) – If True, enables cudnn.deterministic.
distributed_backend (Optional[str]) – Deprecated. Please use accelerator instead.
fast_dev_run (Union[int, bool]) – Runs n batch(es) of train, val and test if set to n (int), or 1 batch if set to True, to find any bugs (i.e., a sort of unit test).
flush_logs_every_n_steps (int) – How often to flush logs to disk (defaults to every 100 steps).
gpus (Union[int, str, List[int], None]) – Number of GPUs to train on (int) or which GPUs to train on (list or str), applied per node.
limit_train_batches (Union[int, float]) – How much of the training dataset to check (float = fraction of batches, int = number of batches).
limit_val_batches (Union[int, float]) – How much of the validation dataset to check (float = fraction of batches, int = number of batches).
limit_test_batches (Union[int, float]) – How much of the test dataset to check (float = fraction of batches, int = number of batches).
logger (Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]) – Logger (or iterable collection of loggers) for experiment tracking.
log_gpu_memory (Optional[str]) – None, ‘min_max’, or ‘all’. Might slow performance.
log_every_n_steps (int) – How often to log within steps (defaults to every 50 steps).
automatic_optimization (Optional[bool]) – If False, you are responsible for calling .backward, .step and zero_grad in the LightningModule. This argument has been moved to LightningModule. It is deprecated here in v1.1 and will be removed in v1.3.
prepare_data_per_node (bool) – If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
process_position (int) – Orders the progress bar when running multiple models on the same machine.
progress_bar_refresh_rate (int) – How often to refresh the progress bar (in steps). A value of 0 disables the progress bar. Ignored when a custom callback is passed to callbacks.
profiler (Union[BaseProfiler, bool, str, None]) – Profiles individual steps during training to help identify bottlenecks. Passing a bool value is deprecated in v1.1 and will be removed in v1.3.
overfit_batches (Union[int, float]) – Overfit a fraction of the training data (float) or a set number of batches (int). Default: 0.0.
plugins (Union[str, list, None]) – Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
precision (int) – Full precision (32) or half precision (16). Can be used on CPU, GPU or TPUs.
max_epochs (int) – Stop training once this number of epochs is reached.
min_epochs (int) – Force training for at least this many epochs.
max_steps (Optional[int]) – Stop training after this number of steps. Disabled by default (None).
min_steps (Optional[int]) – Force training for at least this number of steps. Disabled by default (None).
num_nodes (int) – Number of GPU nodes for distributed training.
num_processes (int) – Number of processes for distributed training with distributed_backend=”ddp_cpu”.
num_sanity_val_steps (int) – Sanity check runs n validation batches before starting the training routine. Set it to -1 to run all batches in all validation dataloaders. Default: 2.
reload_dataloaders_every_epoch (bool) – Set to True to reload dataloaders every epoch.
replace_sampler_ddp (bool) – Explicitly enables or disables sampler replacement. If not specified, this will be toggled automatically when DDP is used. By default it will add shuffle=True for the train sampler and shuffle=False for the val/test sampler. If you want to customize it, you can set replace_sampler_ddp=False and add your own distributed sampler.
resume_from_checkpoint (Union[str, Path, None]) – Path/URL of the checkpoint from which training is resumed. If there is no checkpoint file at the path, start from scratch. If resuming from a mid-epoch checkpoint, training will start from the beginning of the next epoch.
sync_batchnorm (bool) – Synchronize batch norm layers between process groups/whole world.
terminate_on_nan (bool) – If set to True, will terminate training (by raising a ValueError) at the end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
tpu_cores (Union[int, str, List[int], None]) – How many TPU cores to train on (1 or 8), or a single TPU to train on ([1]).
track_grad_norm (Union[int, float, str]) – -1 means no tracking. Otherwise tracks that p-norm. May be set to ‘inf’ for the infinity-norm.
truncated_bptt_steps (Optional[int]) – Truncated backpropagation through time performs backprop every k steps of a much longer sequence.
val_check_interval (Union[int, float]) – How often to check the validation set. Use a float to check within a training epoch, use an int to check every n steps (batches).
weights_summary (Optional[str]) – Prints a summary of the weights when training begins.
weights_save_path (Optional[str]) – Where to save weights if specified. Will override default_root_dir for checkpoints only. Use this if for whatever reason you need the checkpoints stored in a different place than the logs written in default_root_dir. Can be a remote file path such as s3://mybucket/path or ‘hdfs://path/’. Defaults to default_root_dir.
move_metrics_to_cpu (bool) – Whether to force internal logged metrics to be moved to CPU. This can save some GPU memory, but can make training slower. Use with care.
enable_pl_optimizer (Optional[bool]) – If True, each optimizer will be wrapped by pytorch_lightning.core.optimizer.LightningOptimizer. It allows Lightning to handle AMP, TPU, accumulated_gradients, etc.
Warning: Currently deprecated and will be removed in v1.3.
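Several of the parameters above accept either an int or a float, and the two types mean different things. The sketch below illustrates one reading of those rules in plain Python. It is not Lightning's implementation: the helper names are hypothetical, and the rounding-down and capping behavior are assumptions made for the example.

```python
from typing import Union


def resolve_batch_limit(num_batches: int, limit: Union[int, float]) -> int:
    """Number of batches used for a limit_*_batches style value (illustrative only)."""
    if isinstance(limit, bool):
        # bool is a subclass of int; reject it so True is not read as "1 batch".
        raise TypeError("limit must be an int or a float, not a bool")
    if isinstance(limit, int):
        # int: an absolute number of batches (assumed to be capped at what is available)
        return min(num_batches, limit)
    if not 0.0 <= limit <= 1.0:
        raise ValueError("a float limit is a fraction and must lie in [0.0, 1.0]")
    # float: a fraction of the available batches (assumed to round down)
    return int(num_batches * limit)


def validation_every_n_batches(num_train_batches: int, interval: Union[int, float]) -> int:
    """Training batches between validation runs for a val_check_interval style value."""
    if isinstance(interval, int):
        # int: validate every `interval` training batches
        return interval
    if not 0.0 < interval <= 1.0:
        raise ValueError("a float interval is a fraction of an epoch in (0.0, 1.0]")
    # float: validate after this fraction of a training epoch, at least every batch
    return max(1, int(num_train_batches * interval))


if __name__ == "__main__":
    print(resolve_batch_limit(200, 0.25))         # a quarter of the batches
    print(resolve_batch_limit(200, 10))           # exactly ten batches
    print(validation_every_n_batches(200, 0.25))  # four validation runs per epoch
```

Note that 1.0 (float) and 1 (int) are not interchangeable here: the former means the whole dataset or once per epoch, the latter means a single batch.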
-
static available_plugins()
List of all available plugins that can be string arguments to the trainer.
- Returns
List of all available plugins that are supported as string arguments.
-
fit(model, train_dataloader=None, val_dataloaders=None, datamodule=None)
Runs the full optimization routine.
- Parameters
datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
model (LightningModule) – Model to fit.
train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloader method, this will be skipped.
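As a minimal sketch of how the constructor flags and fit() go together, the example below builds a keyword-argument dict from parameters documented on this page and passes it to Trainer. It assumes torch and a pytorch_lightning release matching the signature shown here are installed. TinyRegressor, trainer_flags and main are hypothetical names used only for illustration, and the random tensors stand in for a real dataset.

```python
def trainer_flags(debug: bool = False) -> dict:
    """Collect Trainer keyword arguments; every key is a parameter documented above."""
    flags = {
        "max_epochs": 5,
        "gradient_clip_val": 0.5,
        "progress_bar_refresh_rate": 0,  # 0 disables the progress bar
    }
    if debug:
        # Run a single batch of train/val/test as a quick smoke test.
        flags["fast_dev_run"] = True
    return flags


def main() -> None:
    # Imported lazily so the flag helper above stays usable without these packages.
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class TinyRegressor(pl.LightningModule):
        """Hypothetical one-layer model, just enough to satisfy fit()."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    # Random data standing in for a real training set.
    data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    trainer = pl.Trainer(**trainer_flags(debug=True))
    trainer.fit(TinyRegressor(), train_dataloader=DataLoader(data, batch_size=16))


if __name__ == "__main__":
    main()
```

Keeping the flags in a plain dict makes it easy to switch between a debug run and a full run without touching the training code.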
-
setup_trainer(model)
Sanity check a few things before starting actual training or testing.
- Parameters
model (LightningModule) – The model to run the sanity check on.
-
test(model=None, test_dataloaders=None, ckpt_path='best', verbose=True, datamodule=None)
Separate from fit, to make sure you never run on your test set until you want to.
- Parameters
ckpt_path (Optional[str]) – Either best or the path to the checkpoint you wish to test. If None, use the weights from the last epoch to test. Defaults to best.
datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
model (Optional[LightningModule]) – The model to test.
test_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying test samples.
- Returns
The final test result dictionary. If no test_epoch_end is defined, returns a list of dictionaries.
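The three accepted forms of ckpt_path can be summarized as a small decision rule. The helper below is a hypothetical illustration of that rule, not code from the library: best_model_path stands in for whatever path the checkpoint callback recorded, and raising an error when no best checkpoint exists is an assumption of this sketch.

```python
from typing import Optional


def resolve_test_checkpoint(ckpt_path: Optional[str], best_model_path: str) -> Optional[str]:
    """Return the checkpoint file to load before testing, or None to keep current weights."""
    if ckpt_path is None:
        # None: test with the weights from the last epoch, nothing is loaded.
        return None
    if ckpt_path == "best":
        # 'best': use the best checkpoint recorded during training.
        if not best_model_path:
            # Assumed behavior for this sketch when no best checkpoint was saved.
            raise ValueError("ckpt_path='best' requested but no best checkpoint is known")
        return best_model_path
    # Anything else is treated as an explicit path to a checkpoint file.
    return ckpt_path


if __name__ == "__main__":
    print(resolve_test_checkpoint("best", "checkpoints/epoch=3.ckpt"))
    print(resolve_test_checkpoint(None, "checkpoints/epoch=3.ckpt"))
```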
-
tune(model, train_dataloader=None, val_dataloaders=None, datamodule=None)
Runs routines to tune hyperparameters before training.
- Parameters
datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
model (LightningModule) – Model to tune.
train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloader method, this will be skipped.
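The auto_scale_batch_size parameter of Trainer names two search strategies, power and binsearch. The sketch below shows the general idea behind each in plain Python. It is not Lightning's batch size finder: the fits callback is a hypothetical stand-in for "one training step at this batch size ran without exhausting memory", and the starting size and trial limit are arbitrary choices for the example.

```python
from typing import Callable


def power_search(fits: Callable[[int], bool], start: int = 2, max_trials: int = 25) -> int:
    """Double the batch size until it no longer fits; return the last size that fit."""
    size, best = start, 0
    for _ in range(max_trials):
        if not fits(size):
            break
        best = size
        size *= 2
    return best


def binary_search(fits: Callable[[int], bool], start: int = 2, max_trials: int = 25) -> int:
    """Double until the first failure, then bisect between last success and first failure."""
    low, high, size = 0, None, start
    for _ in range(max_trials):
        if fits(size):
            low = size
            if high is None:
                size *= 2  # still in the doubling phase
                continue
        else:
            high = size
        if high - low <= 1:
            break  # the largest fitting size is pinned down
        size = (low + high) // 2
    return low


if __name__ == "__main__":
    # Pretend that any batch size up to 100 fits in memory.
    fits_in_memory = lambda batch_size: batch_size <= 100
    print(power_search(fits_in_memory))
    print(binary_search(fits_in_memory))
```

The doubling search needs fewer trials but can stop well short of the true limit; bisecting afterwards spends extra trials to close that gap.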