
pytorch_lightning.core.saving module

class pytorch_lightning.core.saving.ModelIO

Bases: object

classmethod _load_model_state(checkpoint, *cls_args, **cls_kwargs)
classmethod load_from_checkpoint(checkpoint_path, *args, map_location=None, hparams_file=None, tags_csv=None, **kwargs)

Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint, it stores the arguments passed to __init__ in the checkpoint under the hyper_parameters key (CHECKPOINT_HYPER_PARAMS_KEY).

Any arguments specified through *args and **kwargs will override the corresponding arguments stored in the checkpoint's hparams.
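The override rule above behaves like a plain dictionary merge in which load-time keyword arguments win. The sketch below uses a hypothetical helper, merge_init_kwargs, and only models the precedence of keyword arguments; it is not Lightning's internal code, and positional *args are not covered.

```python
from typing import Any, Dict


def merge_init_kwargs(stored_hparams: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: combine hparams saved in a checkpoint with
    keyword arguments passed at load time. Later keys win, so load-time
    overrides replace any stored value with the same name."""
    merged = dict(stored_hparams)  # copy so the stored dict is left untouched
    merged.update(overrides)
    return merged


stored = {"num_layers": 4, "drop_prob": 0.2}
init_kwargs = merge_init_kwargs(stored, {"num_layers": 128})
print(init_kwargs)  # {'num_layers': 128, 'drop_prob': 0.2}
```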

Parameters
  • checkpoint_path (str) – Path to checkpoint. This can also be a URL.

  • args – Any positional args needed to init the model.

  • map_location (Union[Dict[str, str], str, device, int, Callable, None]) – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().

  • hparams_file (Optional[str]) –

    Optional path to a .yaml file with hierarchical structure as in this example:

    drop_prob: 0.2
    dataloader:
        batch_size: 32
    

    You most likely won’t need this, since Lightning always saves the hyperparameters to the checkpoint. However, if your checkpoint does not contain the hyperparameters, use this argument to pass in a .yaml file with the hparams you’d like to use. These will be converted into a dict and passed into your LightningModule.

    If your model’s hparams argument is a Namespace and the .yaml file has a hierarchical structure, you need to refactor your model to treat hparams as a dict.

    .csv files are accepted here until v0.9.0; see the tags_csv argument for detailed usage.

  • tags_csv (Optional[str]) –

    Warning

    Deprecated since version 0.7.6.

    The tags_csv argument is deprecated in v0.7.6 and will be removed in v0.9.0.

    Optional path to a .csv file with two columns (key, value) as in this example:

    key,value
    drop_prob,0.2
    batch_size,32
    

    Use this argument to pass in a .csv file with the hparams you’d like to use.

  • hparam_overrides – A dictionary with keys to override in the hparams

  • kwargs – Any keyword args needed to init the model.
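The note about Namespace and hierarchical .yaml files can be reproduced with the standard library alone. This sketch assumes the example file has already been parsed into the nested dict it describes (the YAML parsing step is omitted); it shows that argparse.Namespace turns only the top level into attributes, which is why nested access must use dict indexing.

```python
from argparse import Namespace

# The hierarchical hparams example above, as the dict it would load into.
hparams = {"drop_prob": 0.2, "dataloader": {"batch_size": 32}}

# Wrapping it in a Namespace converts only the top level to attributes.
ns = Namespace(**hparams)
print(ns.drop_prob)                  # 0.2
print(type(ns.dataloader).__name__)  # dict, so ns.dataloader.batch_size would fail

# Treating hparams as a dict works at every nesting level.
print(hparams["dataloader"]["batch_size"])  # 32
```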

Returns

LightningModule with loaded weights and hyperparameters (if available).

Example

# MyLightningModule is your own LightningModule subclass;
# PATH, NEW_PATH and x below are placeholders for your own values.

# load weights without mapping ...
pretrained_model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1': 'cuda:0'}
pretrained_model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

# or load weights and hyperparameters from separate files.
pretrained_model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
pretrained_model = MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)

classmethod load_from_metrics(weights_path, tags_csv, map_location=None)

Warning

Deprecated in version 0.7.0. You should use load_from_checkpoint() instead. Will be removed in v0.9.0.

on_hpc_load(checkpoint)

Hook to do whatever you need right before the SLURM manager loads the model.

Parameters

checkpoint (Dict[str, Any]) – A dictionary with variables from the checkpoint.

Return type

None

on_hpc_save(checkpoint)

Hook to do whatever you need right before the SLURM manager saves the model.

Parameters

checkpoint (Dict[str, Any]) – A dictionary in which you can store variables to be saved in the checkpoint. Contents need to be picklable.

Return type

None

on_load_checkpoint(checkpoint)

Do something with the checkpoint. Gives the model a chance to load something before state_dict is restored.

Parameters

checkpoint (Dict[str, Any]) – A dictionary with variables from the checkpoint.

Return type

None

on_save_checkpoint(checkpoint)

Gives the model a chance to add something to the checkpoint. The state_dict is already there.

Parameters

checkpoint (Dict[str, Any]) – A dictionary in which you can store variables to be saved in the checkpoint. Contents need to be picklable.

Return type

None
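To illustrate how on_save_checkpoint and on_load_checkpoint pair up, here is a framework-free sketch. ToyModel is a hypothetical stand-in for a LightningModule, the vocab key is an arbitrary example, and pickle.dumps/pickle.loads stand in for the real checkpoint I/O; the round trip shows why anything added to the checkpoint must be picklable.

```python
import pickle
from typing import Any, Dict


class ToyModel:
    """Hypothetical stand-in for a LightningModule; only the two hooks are modelled."""

    def __init__(self) -> None:
        self.vocab = ["<pad>", "hello", "world"]  # extra state kept outside state_dict

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Add extra, picklable state next to what is already stored.
        checkpoint["vocab"] = list(self.vocab)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Read the extra state back before the weights would be restored.
        self.vocab = checkpoint["vocab"]


model = ToyModel()
checkpoint: Dict[str, Any] = {"state_dict": {}}
model.on_save_checkpoint(checkpoint)
blob = pickle.dumps(checkpoint)  # stands in for writing the checkpoint file

restored = ToyModel()
restored.vocab = []
restored.on_load_checkpoint(pickle.loads(blob))
print(restored.vocab)  # ['<pad>', 'hello', 'world']
```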

CHECKPOINT_HYPER_PARAMS_KEY = 'hyper_parameters'
CHECKPOINT_HYPER_PARAMS_NAME = 'hparams_name'
CHECKPOINT_HYPER_PARAMS_TYPE = 'hparams_type'

pytorch_lightning.core.saving._convert_loaded_hparams(model_args, hparams_type=None)

Convert hparams according to the given type, which may be a callable or a string (the format used by older checkpoints).

Return type

object
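The description above is terse. One plausible reading, sketched here under stated assumptions, is: no type means the dict is returned unchanged, a callable is applied to the dict, and a string (as stored by older checkpoints) is first resolved to a callable. The function name and the _KNOWN_TYPES table are hypothetical and do not reflect Lightning's actual lookup.

```python
from argparse import Namespace
from typing import Any, Callable, Dict, Optional, Union

# Hypothetical lookup table for string (older) type names; not the library's.
_KNOWN_TYPES: Dict[str, Callable[[dict], Any]] = {
    "Namespace": lambda d: Namespace(**d),
    "dict": dict,
}


def convert_loaded_hparams_sketch(
    model_args: dict,
    hparams_type: Optional[Union[Callable[[dict], Any], str]] = None,
) -> Any:
    """Hypothetical illustration of converting a loaded hparams dict."""
    if not hparams_type:
        # No type recorded in the checkpoint: return the dict unchanged.
        return model_args
    if isinstance(hparams_type, str):
        # String format: resolve the name, falling back to a plain dict.
        hparams_type = _KNOWN_TYPES.get(hparams_type, dict)
    # Callable format: build the target object from the dict.
    return hparams_type(model_args)


args = {"lr": 0.01, "batch_size": 32}
ns = convert_loaded_hparams_sketch(args, "Namespace")
print(ns.lr, ns.batch_size)  # 0.01 32
```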

pytorch_lightning.core.saving.convert(val)

Return type

Union[int, float, bool, str]
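Only the return type of convert is documented here. Judging from that type and the key,value CSV format described under tags_csv, a helper like this turns cell strings back into scalars. The sketch below is a hypothetical stand-in (convert_sketch), and its check order (booleans, then int, then float, then the original string) is an assumption rather than documented library behaviour.

```python
from typing import Union


def convert_sketch(val: str) -> Union[int, float, bool, str]:
    """Hypothetical helper: turn a CSV cell string into a Python scalar."""
    # Booleans first, case-insensitively.
    if val.lower() == "true":
        return True
    if val.lower() == "false":
        return False
    # Then try progressively looser numeric types.
    for constructor in (int, float):
        try:
            return constructor(val)
        except ValueError:
            pass
    # Fall back to the original string.
    return val


print([convert_sketch(v) for v in ["32", "0.2", "True", "./data"]])
# [32, 0.2, True, './data']
```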

pytorch_lightning.core.saving.load_hparams_from_tags_csv(tags_csv)

Load hparams from a .csv file.

>>> import os
>>> from argparse import Namespace
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
>>> path_csv = os.path.join('.', 'testing-hparams.csv')
>>> save_hparams_to_tags_csv(path_csv, hparams)
>>> hparams_new = load_hparams_from_tags_csv(path_csv)
>>> vars(hparams) == hparams_new
True
>>> os.remove(path_csv)

Return type

Dict[str, Any]

pytorch_lightning.core.saving.load_hparams_from_yaml(config_yaml)

Load hparams from a .yaml file.

>>> import os
>>> from argparse import Namespace
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
>>> path_yaml = './testing-hparams.yaml'
>>> save_hparams_to_yaml(path_yaml, hparams)
>>> hparams_new = load_hparams_from_yaml(path_yaml)
>>> vars(hparams) == hparams_new
True
>>> os.remove(path_yaml)

Return type

Dict[str, Any]

pytorch_lightning.core.saving.save_hparams_to_tags_csv(tags_csv, hparams)

Return type

None
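The two-column key,value layout described under tags_csv can be written and read with Python's standard csv module. The sketch below uses hypothetical function names and is not the library's implementation; note that every value comes back as a string, so a separate conversion step is needed to recover numbers and booleans.

```python
import csv
import os
import tempfile
from typing import Any, Dict


def save_tags_csv_sketch(path: str, hparams: Dict[str, Any]) -> None:
    """Hypothetical writer: store hparams as a two-column key,value CSV file."""
    with open(path, "w", newline="") as fp:
        writer = csv.DictWriter(fp, fieldnames=["key", "value"])
        writer.writeheader()
        for key, value in hparams.items():
            writer.writerow({"key": key, "value": value})


def load_tags_csv_sketch(path: str) -> Dict[str, str]:
    """Hypothetical reader: every value is returned as a string."""
    with open(path, newline="") as fp:
        return {row["key"]: row["value"] for row in csv.DictReader(fp)}


path = os.path.join(tempfile.mkdtemp(), "tags.csv")
save_tags_csv_sketch(path, {"drop_prob": 0.2, "batch_size": 32})
loaded = load_tags_csv_sketch(path)
print(loaded)  # {'drop_prob': '0.2', 'batch_size': '32'}
```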

pytorch_lightning.core.saving.save_hparams_to_yaml(config_yaml, hparams)

Return type

None

pytorch_lightning.core.saving.update_hparams(hparams, updates)

Overrides hparams with new values.

>>> hparams = {'c': 4}
>>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1})
>>> hparams['a']['b'], hparams['c']
(2, 1)
>>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7})
>>> hparams['a']['b'], hparams['c']
(4, 7)

Parameters
  • hparams (dict) – the original params and also target object

  • updates (dict) – new params to be used as update

Return type

None
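One way to obtain the behaviour shown in the doctest above is a recursive, in-place merge. The sketch below (update_hparams_sketch, a hypothetical name) satisfies that doctest and also keeps sibling keys inside nested dicts; it is an illustration, not necessarily how the library implements update_hparams.

```python
from typing import Any, Dict


def update_hparams_sketch(hparams: Dict[str, Any], updates: Dict[str, Any]) -> None:
    """Hypothetical recursive merge of ``updates`` into ``hparams``, in place."""
    for key, value in updates.items():
        if key not in hparams:
            # New key: copy it over as-is.
            hparams[key] = value
        elif isinstance(value, dict) and isinstance(hparams[key], dict):
            # Both sides are nested dicts: merge one level deeper.
            update_hparams_sketch(hparams[key], value)
        else:
            # Scalar or type mismatch: the update wins.
            hparams[key] = value


hparams = {"c": 4}
update_hparams_sketch(hparams, {"a": {"b": 2}, "c": 1})
update_hparams_sketch(hparams, {"a": {"b": 4}, "c": 7})
print(hparams)  # {'c': 7, 'a': {'b': 4}}
```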