pytorch_lightning.trainer.trainer module
class pytorch_lightning.trainer.trainer.Trainer(logger=True, checkpoint_callback=True, early_stop_callback=False, callbacks=None, default_root_dir=None, gradient_clip_val=0, process_position=0, num_nodes=1, num_processes=1, gpus=None, auto_select_gpus=False, num_tpu_cores=None, log_gpu_memory=None, progress_bar_refresh_rate=1, overfit_pct=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_epochs=1000, min_epochs=1, max_steps=None, min_steps=None, train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=1.0, log_save_interval=100, row_log_interval=10, add_row_log_interval=None, distributed_backend=None, precision=32, print_nan_grads=False, weights_summary='full', weights_save_path=None, num_sanity_val_steps=2, truncated_bptt_steps=None, resume_from_checkpoint=None, profiler=None, benchmark=False, deterministic=False, reload_dataloaders_every_epoch=False, auto_lr_find=False, replace_sampler_ddp=True, progress_bar_callback=True, terminate_on_nan=False, auto_scale_batch_size=False, amp_level='O1', default_save_path=None, gradient_clip=None, nb_gpu_nodes=None, max_nb_epochs=None, min_nb_epochs=None, use_amp=None, show_progress_bar=None, nb_sanity_val_steps=None, **kwargs)[source]

Bases:
pytorch_lightning.trainer.training_io.TrainerIOMixin,
pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin,
pytorch_lightning.trainer.auto_mix_precision.TrainerAMPMixin,
pytorch_lightning.trainer.distrib_parts.TrainerDPMixin,
pytorch_lightning.trainer.distrib_data_parallel.TrainerDDPMixin,
pytorch_lightning.trainer.logging.TrainerLoggingMixin,
pytorch_lightning.trainer.model_hooks.TrainerModelHooksMixin,
pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin,
pytorch_lightning.trainer.data_loading.TrainerDataLoadingMixin,
pytorch_lightning.trainer.evaluation_loop.TrainerEvaluationLoopMixin,
pytorch_lightning.trainer.training_loop.TrainerTrainLoopMixin,
pytorch_lightning.trainer.callback_config.TrainerCallbackConfigMixin,
pytorch_lightning.trainer.callback_hook.TrainerCallbackHookMixin,
pytorch_lightning.trainer.lr_finder.TrainerLRFinderMixin,
pytorch_lightning.trainer.deprecated_api.TrainerDeprecatedAPITillVer0_8,
pytorch_lightning.trainer.deprecated_api.TrainerDeprecatedAPITillVer0_9
Customize every aspect of training via flags
Parameters
- logger (Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]) – Logger (or iterable collection of loggers) for experiment tracking.
- checkpoint_callback (Union[ModelCheckpoint, bool]) – Callback for checkpointing.
- early_stop_callback (pytorch_lightning.callbacks.EarlyStopping) –
- callbacks (Optional[List[Callback]]) – Add a list of callbacks.
- default_root_dir (Optional[str]) – Default path for logs and weights when no logger/ckpt_callback passed.
- default_save_path – Warning: Deprecated since version 0.7.3. Use default_root_dir instead. Will remove 0.9.0.
- gradient_clip – Warning: Deprecated since version 0.7.0. Use gradient_clip_val instead. Will remove 0.9.0.
- process_position (int) – Orders the progress bar when running multiple models on the same machine.
- num_nodes (int) – Number of GPU nodes for distributed training.
- nb_gpu_nodes – Warning: Deprecated since version 0.7.0. Use num_nodes instead. Will remove 0.9.0.
- gpus (Union[List[int], str, int, None]) – Which GPUs to train on.
- auto_select_gpus (bool) – If enabled and gpus is an integer, pick available GPUs automatically. This is especially useful when GPUs are configured to be in "exclusive mode", such that only one process at a time can access them.
- num_tpu_cores (Optional[int]) – How many TPU cores to train on (1 or 8).
- log_gpu_memory (Optional[str]) – None, 'min_max', 'all'. Might slow performance.
- show_progress_bar – Warning: Deprecated since version 0.7.2. Set progress_bar_refresh_rate to a positive integer to enable. Will remove 0.9.0.
- progress_bar_refresh_rate (int) – How often to refresh the progress bar (in steps). Value 0 disables the progress bar. Ignored when a custom callback is passed to callbacks.
- overfit_pct (float) – How much of the training, validation, and test datasets to check.
- track_grad_norm (int) – -1 means no tracking. Otherwise tracks that norm.
- check_val_every_n_epoch (int) – Check val every n train epochs.
- fast_dev_run (bool) – Runs 1 batch of train, test and val to find any bugs (i.e. a sort of unit test).
- accumulate_grad_batches (Union[int, Dict[int, int], List[list]]) – Accumulates grads every k batches or as set up in the dict (see the sketch after this parameter list).
- max_epochs (int) – Stop training once this number of epochs is reached.
- max_nb_epochs – Warning: Deprecated since version 0.7.0. Use max_epochs instead. Will remove 0.9.0.
- min_epochs (int) – Force training for at least this many epochs.
- min_nb_epochs – Warning: Deprecated since version 0.7.0. Use min_epochs instead. Will remove 0.9.0.
- max_steps (Optional[int]) – Stop training after this number of steps. Disabled by default (None).
- min_steps (Optional[int]) – Force training for at least this number of steps. Disabled by default (None).
- train_percent_check (float) – How much of the training dataset to check.
- val_percent_check (float) – How much of the validation dataset to check.
- test_percent_check (float) – How much of the test dataset to check.
- val_check_interval (float) – How often within one training epoch to check the validation set.
- row_log_interval (int) – How often to add logging rows (does not write to disk).
- add_row_log_interval – Warning: Deprecated since version 0.7.0. Use row_log_interval instead. Will remove 0.9.0.
- distributed_backend (Optional[str]) – The distributed backend to use.
- use_amp – Warning: Deprecated since version 0.7.0. Use precision instead. Will remove 0.9.0.
- precision (int) – Full precision (32), half precision (16).
- print_nan_grads – Warning: Deprecated since version 0.7.2. Has no effect. When detected, NaN grads will be printed automatically. Will remove 0.9.0.
- weights_summary (Optional[str]) – Prints a summary of the weights when training begins.
- weights_save_path (Optional[str]) – Where to save weights if specified. Will override default_root_dir for checkpoints only. Use this if for whatever reason you need the checkpoints stored in a different place than the logs written in default_root_dir.
- amp_level (str) – The optimization level to use (O1, O2, etc.).
- num_sanity_val_steps (int) – Sanity check runs n batches of val before starting the training routine.
- nb_sanity_val_steps – Warning: Deprecated since version 0.7.0. Use num_sanity_val_steps instead. Will remove 0.8.0.
- truncated_bptt_steps (Optional[int]) – Truncated back prop breaks performs backprop every k steps of a much longer sequence.
- resume_from_checkpoint (Optional[str]) – To resume training from a specific checkpoint, pass in the path here.
- profiler (Union[BaseProfiler, bool, None]) – To profile individual steps during training and assist in identifying bottlenecks.
- reload_dataloaders_every_epoch (bool) – Set to True to reload dataloaders every epoch.
- auto_lr_find (Union[bool, str]) – If set to True, will initially run a learning rate finder, trying to optimize the initial learning rate for faster convergence. Sets the learning rate in self.hparams.lr | self.hparams.learning_rate in the LightningModule. To use a different key, set a string instead of True with the key name.
- replace_sampler_ddp (bool) – Explicitly enables or disables sampler replacement. If not specified, this will be toggled automatically when DDP is used.
- terminate_on_nan (bool) – If set to True, will terminate training (by raising a ValueError) at the end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
- auto_scale_batch_size (Union[str, bool]) – If set to True, will initially run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.hparams.batch_size in the LightningModule. Additionally, can be set to either 'power', which estimates the batch size through a power search, or 'binsearch', which estimates the batch size through a binary search.
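For orientation, here is a minimal sketch of constructing a Trainer with a few of the flags documented above, including the dict form of accumulate_grad_batches. MyModel is a hypothetical LightningModule (assumed to define its own dataloaders) and the values shown are illustrative, not defaults:

from pytorch_lightning import Trainer

# MyModel is a hypothetical LightningModule defined elsewhere
model = MyModel()

trainer = Trainer(
    max_epochs=10,                         # stop after 10 epochs
    gpus=1,                                # train on a single GPU
    gradient_clip_val=0.5,                 # clip gradients to this norm
    accumulate_grad_batches={0: 1, 4: 4},  # no accumulation until epoch 4, then accumulate every 4 batches
    val_check_interval=0.25,               # run validation 4 times per training epoch
)

# the model defines its own train/val dataloaders in this sketch
trainer.fit(model)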
_Trainer__attach_dataloaders(model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None)[source]

_Trainer__set_random_port()[source]
When running DDP not managed by SLURM, the ports might collide.
classmethod add_argparse_args(parent_parser)[source]
Extends existing argparse by default Trainer attributes.

Parameters
parent_parser (ArgumentParser) – The custom CLI arguments parser, which will be extended by the Trainer default arguments.

Only arguments of the allowed types (str, float, int, bool) will extend the parent_parser.

Examples
>>> import argparse
>>> import pprint
>>> parser = argparse.ArgumentParser()
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
>>> pprint.pprint(vars(args))
{...
 'check_val_every_n_epoch': 1,
 'checkpoint_callback': True,
 'default_root_dir': None,
 'deterministic': False,
 'distributed_backend': None,
 'early_stop_callback': False,
 ...
 'logger': True,
 'max_epochs': 1000,
 'max_steps': None,
 'min_epochs': 1,
 'min_steps': None,
 ...
 'profiler': None,
 'progress_bar_callback': True,
 'progress_bar_refresh_rate': 1,
 ...}

Return type
ArgumentParser
check_model_configuration(model)[source]
Checks that the model is configured correctly before training is started.

Parameters
model (LightningModule) – The model to test.
fit(model, train_dataloader=None, val_dataloaders=None)[source]
Runs the full optimization routine.

Parameters
- model (LightningModule) – Model to fit.
- train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
- val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method, this will be skipped.

Example:

# Option 1
# Define the train_dataloader() and val_dataloader() fxs
# in the LightningModule
# RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
trainer = Trainer()
model = LightningModule()
trainer.fit(model)

# Option 2
# in production cases we might want to pass different datasets to the same model
# Recommended for PRODUCTION SYSTEMS
train, val = DataLoader(...), DataLoader(...)
trainer = Trainer()
model = LightningModule()
trainer.fit(model, train_dataloader=train, val_dataloaders=val)

# Option 1 & 2 can be mixed, for example the training set can be
# defined as part of the model, and validation can then be fed to .fit()
classmethod from_argparse_args(args, **kwargs)[source]
Create an instance from CLI arguments.

Example
>>> parser = ArgumentParser(add_help=False)
>>> parser = Trainer.add_argparse_args(parser)
>>> args = Trainer.parse_argparser(parser.parse_args(""))
>>> trainer = Trainer.from_argparse_args(args)

Return type
Trainer
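Putting add_argparse_args, parse_argparser and from_argparse_args together, here is a hedged sketch of a small training script driven entirely from the command line; MyModel is a hypothetical LightningModule and the script layout is illustrative:

from argparse import ArgumentParser
from pytorch_lightning import Trainer

def main():
    parser = ArgumentParser(add_help=False)
    # expose every Trainer constructor flag of an allowed type as a CLI argument
    parser = Trainer.add_argparse_args(parser)
    args = Trainer.parse_argparser(parser.parse_args())

    # build the Trainer straight from the parsed Namespace
    trainer = Trainer.from_argparse_args(args)

    model = MyModel()  # hypothetical LightningModule defined elsewhere
    trainer.fit(model)

if __name__ == '__main__':
    main()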
classmethod get_deprecated_arg_names()[source]
Returns a list with deprecated Trainer arguments.

Return type
List
classmethod get_init_arguments_and_types()[source]
Scans the Trainer signature and returns argument names, types and default values.

Returns
(argument name, set with argument types, argument default value).

Return type
List with tuples of 3 values

Examples
>>> args = Trainer.get_init_arguments_and_types()
>>> import pprint
>>> pprint.pprint(sorted(args))
[('accumulate_grad_batches',
  (<class 'int'>, typing.Dict[int, int], typing.List[list]),
  1),
 ...
 ('callbacks',
  (typing.List[pytorch_lightning.callbacks.base.Callback], <class 'NoneType'>),
  None),
 ('check_val_every_n_epoch', (<class 'int'>,), 1),
 ...
 ('max_epochs', (<class 'int'>,), 1000),
 ...
 ('precision', (<class 'int'>,), 32),
 ('print_nan_grads', (<class 'bool'>,), False),
 ('process_position', (<class 'int'>,), 0),
 ('profiler',
  (<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
   <class 'bool'>,
   <class 'NoneType'>),
  None),
 ...
static parse_argparser(arg_parser)[source]
Parse CLI arguments, required for custom bool types.

Return type
Namespace
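A hedged sketch of why this helper exists: boolean Trainer flags arrive from the command line as strings, and parse_argparser is the step that normalises the parsed Namespace. The exact flag syntax accepted (for example, passing an explicit 'True' value) depends on how add_argparse_args registered the argument in this version, so treat the simulated command line below as an assumption:

from argparse import ArgumentParser
from pytorch_lightning import Trainer

parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parser)

# simulate a command line such as: python train.py --fast_dev_run True
raw = parser.parse_args(['--fast_dev_run', 'True'])
args = Trainer.parse_argparser(raw)
print(args.fast_dev_run)  # expected to be a real bool, not the string 'True'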
run_pretrain_routine(model)[source]
Sanity check a few things before starting actual training.

Parameters
model (LightningModule) – The model to run the sanity test on.
test(model=None, test_dataloaders=None)[source]
Separate from fit to make sure you never run on your test set until you want to.

Parameters
- model (Optional[LightningModule]) – The model to test.
- test_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying test samples.

Example:

# Option 1
# run test after fitting
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()

trainer.fit(model)
trainer.test(test_dataloaders=test)

# Option 2
# run test from a loaded model
test = DataLoader(...)
model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
trainer = Trainer()
trainer.test(model, test_dataloaders=test)
DEPRECATED_IN_0_8 = ('gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs', 'add_row_log_interval', 'nb_sanity_val_steps', 'tng_tqdm_dic')[source]
class pytorch_lightning.trainer.trainer._PatchDataLoader(dataloader)[source]

Bases: object

Callable object for patching dataloaders passed into trainer.fit(). Use this class to override model.*_dataloader() and be pickle-compatible.

Parameters
dataloader (Union[List[DataLoader], DataLoader]) – Dataloader object to return when called.
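A minimal sketch of the pattern this class implements: a small callable object that simply returns a pre-built DataLoader, so it can stand in for a model.*_dataloader() method while remaining picklable (a lambda closure would not be). The class name below is a hypothetical stand-in, not the private implementation itself:

import torch
from torch.utils.data import DataLoader, TensorDataset

class ReturnDataLoader:  # hypothetical stand-in for _PatchDataLoader
    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __call__(self, *args, **kwargs):
        # ignore any arguments and hand back the wrapped loader
        return self.dataloader

loader = DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=4)
patch = ReturnDataLoader(loader)
assert patch() is loader  # behaves like a *_dataloader() method returning the loader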