pytorch_lightning.core.saving module
- class pytorch_lightning.core.saving.ModelIO[source]
Bases: object
- classmethod load_from_checkpoint(checkpoint_path, *args, map_location=None, hparams_file=None, tags_csv=None, **kwargs)[source]
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to __init__ in the checkpoint under module_arguments.
Any arguments specified through *args and **kwargs will override args stored in hparams.
- Parameters
checkpoint_path (str) – Path to checkpoint. This can also be a URL.
args – Any positional args needed to init the model.
map_location (Union[Dict[str, str], str, device, int, Callable, None]) – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().
hparams_file (Optional[str]) – Optional path to a .yaml file with hierarchical structure as in this example:

    drop_prob: 0.2
    dataloader:
        batch_size: 32

You most likely won’t need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don’t have the hyperparameters saved, use this argument to pass in a .yaml file with the hparams you’d like to use. These will be converted into a dict and passed into your LightningModule for use.
If your model’s hparams argument is a Namespace and the .yaml file has a hierarchical structure, you need to refactor your model to treat hparams as a dict (see the sketch after the example below).
.csv files are acceptable here until v0.9.0; see the tags_csv argument for detailed usage.
tags_csv (Optional[str]) –
Warning
Deprecated since version 0.7.6. The tags_csv argument is deprecated in v0.7.6 and will be removed in v0.9.0.
Optional path to a .csv file with two columns (key, value) as in this example:

    key,value
    drop_prob,0.2
    batch_size,32

Use this argument to pass in a .csv file with the hparams you’d like to use.
hparam_overrides – A dictionary with keys to override in the hparams.
kwargs – Any keyword args needed to init the model.
- Returns
LightningModule with loaded weights and hyperparameters (if available).
Example
# load weights without mapping ...
MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1': 'cuda:0'}
MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

# or load weights and hyperparameters from separate files.
MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
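Where the hparams_file note above says to refactor a Namespace-style model to treat hparams as a dict, a minimal sketch might look like the following (the class name and hyperparameter keys are illustrative, reusing the .yaml example above):

import torch
import pytorch_lightning as pl

class MyLightningModule(pl.LightningModule):
    """Hypothetical module that indexes hparams as a plain dict,
    so a hierarchical .yaml passed via hparams_file maps onto it cleanly."""

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # dict-style access instead of Namespace attribute access
        self.dropout = torch.nn.Dropout(hparams['drop_prob'])
        self.batch_size = hparams['dataloader']['batch_size']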
- classmethod load_from_metrics(weights_path, tags_csv, map_location=None)[source]
Warning
Deprecated since version 0.7.0. You should use load_from_checkpoint() instead. Will be removed in v0.9.0.
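A brief migration sketch from the deprecated call to the current one (the class name and file paths are illustrative):

# deprecated (removed in v0.9.0)
model = MyLightningModule.load_from_metrics(
    weights_path='path/to/checkpoint.ckpt',
    tags_csv='path/to/meta_tags.csv',
)

# preferred
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    tags_csv='path/to/meta_tags.csv',  # itself deprecated; prefer hparams_file
)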
- on_hpc_load(checkpoint)[source]
Hook to do whatever you need right before Slurm manager loads the model.
- on_hpc_save(checkpoint)[source]
Hook to do whatever you need right before Slurm manager saves the model.
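As a minimal sketch, assuming you want to round-trip some extra state through an HPC checkpoint, both hooks can be overridden like this (the 'hpc_extra' key and some_state attribute are hypothetical):

import pytorch_lightning as pl

class MyLightningModule(pl.LightningModule):

    def on_hpc_save(self, checkpoint):
        # stash extra state before the Slurm manager writes the checkpoint
        checkpoint['hpc_extra'] = self.some_state  # hypothetical attribute

    def on_hpc_load(self, checkpoint):
        # restore it right before the model is reloaded
        self.some_state = checkpoint.get('hpc_extra')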
- on_load_checkpoint(checkpoint)[source]
Do something with the checkpoint. Gives the model a chance to load something before state_dict is restored.
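A minimal sketch of an override, assuming a custom entry was written into the checkpoint beforehand (the key and attribute names are illustrative):

class MyLightningModule(pl.LightningModule):

    def on_load_checkpoint(self, checkpoint):
        # runs before the state_dict is restored
        self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']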
- pytorch_lightning.core.saving._convert_loaded_hparams(model_args, hparams_type=None)[source]
Convert hparams according to the given type, in callable or string (past) format.
- Return type
- pytorch_lightning.core.saving.load_hparams_from_tags_csv(tags_csv)[source]
Load hparams from a file.
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
>>> path_csv = os.path.join('.', 'testing-hparams.csv')
>>> save_hparams_to_tags_csv(path_csv, hparams)
>>> hparams_new = load_hparams_from_tags_csv(path_csv)
>>> vars(hparams) == hparams_new
True
>>> os.remove(path_csv)
- pytorch_lightning.core.saving.load_hparams_from_yaml(config_yaml)[source]
Load hparams from a file.
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
>>> path_yaml = './testing-hparams.yaml'
>>> save_hparams_to_yaml(path_yaml, hparams)
>>> hparams_new = load_hparams_from_yaml(path_yaml)
>>> vars(hparams) == hparams_new
True
>>> os.remove(path_yaml)
- pytorch_lightning.core.saving.update_hparams(hparams, updates)[source]
Overrides hparams with new values.
>>> hparams = {'c': 4}
>>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1})
>>> hparams['a']['b'], hparams['c']
(2, 1)
>>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7})
>>> hparams['a']['b'], hparams['c']
(4, 7)