
datamodule

Classes

LightningDataModule

A DataModule standardizes the training, validation, and test splits, data preparation, and transforms.

LightningDataModule for loading DataLoaders with ease.

class pytorch_lightning.core.datamodule.LightningDataModule(*args: Any, **kwargs: Any)

Bases: pytorch_lightning.core.hooks.CheckpointHooks, pytorch_lightning.core.hooks.DataHooks, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

A DataModule standardizes the training, validation, and test splits, data preparation, and transforms. The main advantage is consistent data splits, data preparation, and transforms across models.

Example:

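from typing import Optional

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from pytorch_lightning import LightningDataModule

class MyDataModule(LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size
    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        pass
    def setup(self, stage: Optional[str] = None):
        # make assignments here (train/val/test split)
        # called on every process in DDP
        # (random tensors stand in for a real dataset in this example)
        full = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
        self.train_split, self.val_split, self.test_split = random_split(full, [80, 10, 10])
    def train_dataloader(self):
        return DataLoader(self.train_split, batch_size=self.batch_size)
    def val_dataloader(self):
        return DataLoader(self.val_split, batch_size=self.batch_size)
    def test_dataloader(self):
        return DataLoader(self.test_split, batch_size=self.batch_size)
    def teardown(self, stage: Optional[str] = None):
        # clean up after fit or test
        # called on every process in DDP
        pass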

A DataModule implements 6 key methods:

  • prepare_data (things to do on a single GPU/TPU, not on every GPU/TPU, in distributed mode).

  • setup (things to do on every accelerator in distributed mode).

  • train_dataloader (the training dataloader).

  • val_dataloader (the validation dataloader(s)).

  • test_dataloader (the test dataloader(s)).

  • teardown (things to do on every accelerator in distributed mode when finished).

This lets you share a complete dataset without having to explain how to download, split, transform, and process the data.

classmethod add_argparse_args(parent_parser, **kwargs)

Extends an existing argparse parser with the default LightningDataModule attributes.

Return type

ArgumentParser

classmethod from_argparse_args(args, **kwargs)

Create an instance from CLI arguments.

Parameters
  • args (Union[Namespace, ArgumentParser]) – The parser or namespace to take arguments from. Only known arguments will be parsed and passed to the LightningDataModule.

  • **kwargs – Additional keyword arguments that may override ones in the parser or namespace. These must be valid DataModule arguments.
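The filtering described above (keep only the arguments that the DataModule's `__init__` accepts, then let explicit keyword arguments win) can be sketched with the standard library alone. This is an illustration of the idea under that assumption, not Lightning's implementation; `build_from_namespace` and this `MyDataModule` are hypothetical names.

```python
import inspect
from argparse import Namespace
from typing import Any, Dict


class MyDataModule:
    """Hypothetical stand-in for a LightningDataModule subclass."""

    def __init__(self, data_dir: str = "./data", batch_size: int = 32) -> None:
        self.data_dir = data_dir
        self.batch_size = batch_size


def build_from_namespace(cls: type, args: Namespace, **kwargs: Any) -> Any:
    # Names accepted by cls.__init__, excluding "self".
    valid = set(inspect.signature(cls.__init__).parameters) - {"self"}
    # Keep only known arguments from the namespace; silently ignore the rest.
    init_kwargs: Dict[str, Any] = {k: v for k, v in vars(args).items() if k in valid}
    # Explicit keyword arguments override values taken from the namespace.
    init_kwargs.update(kwargs)
    return cls(**init_kwargs)


args = Namespace(data_dir="/tmp/mnist", batch_size=64, max_epochs=10)
dm = build_from_namespace(MyDataModule, args, batch_size=128)
```

Here `max_epochs` is dropped because `MyDataModule` does not accept it, and `batch_size=128` overrides the namespace value of 64.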

Example:

from argparse import ArgumentParser
from pytorch_lightning import LightningDataModule

parser = ArgumentParser(add_help=False)
parser = LightningDataModule.add_argparse_args(parser)
args = parser.parse_args()
module = LightningDataModule.from_argparse_args(args)

classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, batch_size=1, num_workers=0)

Create an instance from one or more torch.utils.data.Dataset objects.

Parameters
  • train_dataset (Union[Dataset, Sequence[Dataset], Mapping[str, Dataset], None]) – (optional) Dataset to be used for train_dataloader()

  • val_dataset (Union[Dataset, Sequence[Dataset], None]) – (optional) Dataset or list of Dataset to be used for val_dataloader()

  • test_dataset (Union[Dataset, Sequence[Dataset], None]) – (optional) Dataset or list of Dataset to be used for test_dataloader()

  • batch_size (int) – Batch size to use for each dataloader. Default is 1.

  • num_workers (int) – Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. Default is 0.

classmethod get_init_arguments_and_types()

Scans the DataModule signature and returns argument names, types and default values.

Returns

(argument name, set with argument types, argument default value).

Return type

List with tuples of 3 values
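How such a signature scan can work is sketched below with `inspect.signature`, assuming the types come from the `__init__` annotations and that `Optional[...]`/`Union[...]` are unpacked into their member types. `scan_init_args` and this `MyDataModule` are hypothetical; this is not Lightning's code, and it returns a tuple of types rather than a set.

```python
import inspect
from typing import Any, List, Optional, Tuple


class MyDataModule:
    """Hypothetical DataModule-like class with an annotated __init__."""

    def __init__(self, data_dir: str = "./data", batch_size: int = 32,
                 val_fraction: Optional[float] = None) -> None:
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_fraction = val_fraction


def scan_init_args(cls: type) -> List[Tuple[str, Tuple[Any, ...], Any]]:
    """Return (name, types, default) for each __init__ parameter except self."""
    result = []
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == "self":
            continue
        annotation = param.annotation
        # Optional[X] / Union[X, Y] expose their members via __args__;
        # plain classes such as str or int do not, so wrap them in a tuple.
        types = getattr(annotation, "__args__", (annotation,))
        result.append((name, tuple(types), param.default))
    return result


scanned = scan_init_args(MyDataModule)
```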

size(dim=None)

Return the dimension of each input either as a tuple or list of tuples. You can index this just as you would with a torch tensor.

Return type

Union[Tuple, int]

property dims

A tuple describing the shape of your data. Additional functionality is exposed through size().
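Assuming size() returns `dims` unchanged when called without an argument and `dims[dim]` otherwise, which matches the "index this like a torch tensor" description above, the pair can be sketched with a plain class. `ShapeInfo` and the MNIST-like shape `(1, 28, 28)` are illustrative assumptions.

```python
from typing import Any, Optional, Tuple, Union


class ShapeInfo:
    """Hypothetical stand-in for the dims / size() pair on a DataModule."""

    def __init__(self, dims: Tuple[Any, ...]) -> None:
        self.dims = dims

    def size(self, dim: Optional[int] = None) -> Union[Tuple[Any, ...], Any]:
        # No argument: return the whole shape. Otherwise index into it,
        # the same way torch.Tensor.size(dim) selects one dimension.
        if dim is not None:
            return self.dims[dim]
        return self.dims


mnist_like = ShapeInfo(dims=(1, 28, 28))
```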

property has_prepared_data: bool

Return a bool indicating whether datamodule.prepare_data() has been called.

Returns

True if datamodule.prepare_data() has been called. False by default.

Return type

bool

Deprecated since version v1.4: Will be removed in v1.6.0.

property has_setup_fit: bool

Return a bool indicating whether datamodule.setup(stage='fit') has been called.

Returns

True if datamodule.setup(stage='fit') has been called. False by default.

Return type

bool

Deprecated since version v1.4: Will be removed in v1.6.0.

property has_setup_predict: bool

Return a bool indicating whether datamodule.setup(stage='predict') has been called.

Returns

True if datamodule.setup(stage='predict') has been called. False by default.

Return type

bool

Deprecated since version v1.4: Will be removed in v1.6.0.

property has_setup_test: bool

Return a bool indicating whether datamodule.setup(stage='test') has been called.

Returns

True if datamodule.setup(stage='test') has been called. False by default.

Return type

bool

Deprecated since version v1.4: Will be removed in v1.6.0.

property has_setup_validate: bool

Return a bool indicating whether datamodule.setup(stage='validate') has been called.

Returns

True if datamodule.setup(stage='validate') has been called. False by default.

Return type

bool

Deprecated since version v1.4: Will be removed in v1.6.0.

property has_teardown_fit: bool

Return a bool indicating whether datamodule.teardown(stage='fit') has been called.

Returns

True if datamodule.teardown(stage='fit') has been called. False by default.

Return type

bool

Deprecated since version v1.4: Will be removed in v1.6.0.

property has_teardown_predict: bool

Return a bool indicating whether datamodule.teardown(stage='predict') has been called.

Returns

True if datamodule.teardown(stage='predict') has been called. False by default.

Return type

bool

Deprecated since version v1.4: Will be removed in v1.6.0.

property has_teardown_test: bool

Return a bool indicating whether datamodule.teardown(stage='test') has been called.

Returns

True if datamodule.teardown(stage='test') has been called. False by default.

Return type

bool

Deprecated since version v1.4: Will be removed in v1.6.0.

property has_teardown_validate: bool

Return a bool indicating whether datamodule.teardown(stage='validate') has been called.

Returns

True if datamodule.teardown(stage='validate') has been called. False by default.

Return type

bool

Deprecated since version v1.4: Will be removed in v1.6.0.
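Because these flags are deprecated, code that still needs this information can record it itself. The plain-Python sketch below illustrates one way to do that; it is not Lightning's mechanism, and `TrackingDataModule` and `has_run` are hypothetical names. In a real LightningDataModule subclass, the same bookkeeping would live inside the overridden hooks.

```python
from typing import Optional, Set, Tuple


class TrackingDataModule:
    """Hypothetical DataModule-like class that records which hooks ran, per stage."""

    def __init__(self) -> None:
        # Each entry is (hook name, stage); prepare_data has no stage.
        self._called: Set[Tuple[str, Optional[str]]] = set()

    def prepare_data(self) -> None:
        self._called.add(("prepare_data", None))

    def setup(self, stage: Optional[str] = None) -> None:
        self._called.add(("setup", stage))

    def teardown(self, stage: Optional[str] = None) -> None:
        self._called.add(("teardown", stage))

    def has_run(self, hook: str, stage: Optional[str] = None) -> bool:
        # Plays the role of has_prepared_data / has_setup_<stage> / has_teardown_<stage>.
        return (hook, stage) in self._called


dm = TrackingDataModule()
dm.prepare_data()
dm.setup(stage="fit")
```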

property test_transforms

Optional transforms (or collection of transforms) you can apply to the test dataset.

property train_transforms

Optional transforms (or collection of transforms) you can apply to the training dataset.

property val_transforms

Optional transforms (or collection of transforms) you can apply to the validation dataset.
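These properties only hold the transforms; applying them is assumed here to be the job of your dataset code. A minimal sketch, assuming a dataset that returns `(input, target)` pairs: `TransformedDataset` and `scale` are hypothetical, and inside a DataModule you might pass `self.train_transforms` (or the validation/test counterpart) as `transform`.

```python
from typing import Callable, Optional

import torch
from torch.utils.data import Dataset, TensorDataset


class TransformedDataset(Dataset):
    """Hypothetical wrapper that applies a transform to each input on access."""

    def __init__(self, base: Dataset, transform: Optional[Callable] = None) -> None:
        self.base = base
        self.transform = transform

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        x, y = self.base[idx]
        # Only the input is transformed; the target passes through unchanged.
        if self.transform is not None:
            x = self.transform(x)
        return x, y


def scale(x: torch.Tensor) -> torch.Tensor:
    # Illustrative transform: map pixel values from [0, 255] to [0, 1].
    return x / 255.0


raw = TensorDataset(torch.full((3, 2), 255.0), torch.tensor([0, 1, 2]))
train_set = TransformedDataset(raw, transform=scale)
```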