datamodule

LightningDataModule for loading DataLoaders with ease.

Functions

track_data_hook_calls: A decorator that checks if prepare_data/setup have been called.

Classes

LightningDataModule: A DataModule standardizes the training, val, test splits, data preparation and transforms.
class pytorch_lightning.core.datamodule.LightningDataModule(*args, **kwargs)

Bases: pytorch_lightning.core.hooks.CheckpointHooks, pytorch_lightning.core.hooks.DataHooks

A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.
Example:
    class MyDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()

        def prepare_data(self):
            # download, split, etc...
            # only called on 1 GPU/TPU in distributed
            ...

        def setup(self, stage=None):
            # make assignments here (val/train/test split)
            # called on every process in DDP
            ...

        def train_dataloader(self):
            train_split = Dataset(...)
            return DataLoader(train_split)

        def val_dataloader(self):
            val_split = Dataset(...)
            return DataLoader(val_split)

        def test_dataloader(self):
            test_split = Dataset(...)
            return DataLoader(test_split)
A DataModule implements 5 key methods:

prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode).
setup (things to do on every accelerator in distributed mode).
train_dataloader: the training dataloader.
val_dataloader: the val dataloader(s).
test_dataloader: the test dataloader(s).

This allows you to share a full dataset without explaining how to download, split, transform and process the data. A minimal usage sketch follows.
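A minimal sketch of typical usage with a Trainer (MyModel is an assumed LightningModule, not part of this class's API):

    from pytorch_lightning import Trainer

    dm = MyDataModule()
    model = MyModel()  # an assumed LightningModule

    trainer = Trainer()
    trainer.fit(model, datamodule=dm)   # uses train/val dataloaders
    trainer.test(model, datamodule=dm)  # uses the test dataloader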
classmethod add_argparse_args(parent_parser)

Extends an existing argparse parser with the default LightningDataModule attributes.

Return type
    ArgumentParser
classmethod from_argparse_args(args, **kwargs)

Create an instance from CLI arguments.

Parameters
    args (Union[Namespace, ArgumentParser]) – The parser or namespace to take arguments from. Only known arguments will be parsed and passed to the LightningDataModule.
    **kwargs – Additional keyword arguments that may override ones in the parser or namespace. These must be valid DataModule arguments.
Example:
    parser = ArgumentParser(add_help=False)
    parser = LightningDataModule.add_argparse_args(parser)
    args = parser.parse_args()
    module = LightningDataModule.from_argparse_args(args)
classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, batch_size=1, num_workers=0)

Create an instance from torch.utils.data.Dataset.

Parameters
    train_dataset (Union[Dataset, Sequence[Dataset], Mapping[str, Dataset], None]) – (optional) Dataset to be used for train_dataloader().
    val_dataset (Union[Dataset, Sequence[Dataset], None]) – (optional) Dataset or list of Dataset to be used for val_dataloader().
    test_dataset (Union[Dataset, Sequence[Dataset], None]) – (optional) Dataset or list of Dataset to be used for test_dataloader().
    batch_size (int) – Batch size to use for each dataloader. Default is 1.
    num_workers (int) – Number of subprocesses to use for data loading. 0 means the data will be loaded in the main process; a common upper bound is the number of CPUs available.
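Example (a sketch; the TensorDataset construction is an illustrative assumption, since any torch.utils.data.Dataset works here):

    import torch
    from torch.utils.data import TensorDataset
    from pytorch_lightning import LightningDataModule

    # toy tensors standing in for real data
    train_ds = TensorDataset(torch.randn(100, 32), torch.randint(0, 2, (100,)))
    val_ds = TensorDataset(torch.randn(20, 32), torch.randint(0, 2, (20,)))

    dm = LightningDataModule.from_datasets(
        train_dataset=train_ds,
        val_dataset=val_ds,
        batch_size=16,
        num_workers=0,  # load in the main process
    )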
classmethod get_init_arguments_and_types()

Scans the DataModule signature and returns argument names, types and default values.

Returns
    (argument name, set with argument types, argument default value).

Return type
    List with tuples of 3 values
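A brief sketch of reading the returned tuples (the batch_size parameter here is an assumed example, not part of the base class):

    class MyDataModule(LightningDataModule):
        def __init__(self, batch_size: int = 32):
            super().__init__()
            self.batch_size = batch_size

    # each entry is (name, types, default), e.g. ('batch_size', (<class 'int'>,), 32)
    for name, types, default in MyDataModule.get_init_arguments_and_types():
        print(name, types, default)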
abstract prepare_data(*args, **kwargs)

Use this to download and prepare data.

Warning

Do NOT set state on the model here (use setup instead), since prepare_data is NOT called on every GPU in DDP/TPU.
Example:
    def prepare_data(self):
        # good
        download_data()
        tokenize()
        etc()

        # bad
        self.split = data_split
        self.some_state = some_other_state()
In DDP, prepare_data can be called in two ways, controlled by Trainer(prepare_data_per_node=...):

Once per node. This is the default, and is only called on LOCAL_RANK=0.
Once in total. Only called on GLOBAL_RANK=0.
Example:
    # DEFAULT
    # called once per node on LOCAL_RANK=0 of that node
    Trainer(prepare_data_per_node=True)

    # call on GLOBAL_RANK=0 (great for shared file systems)
    Trainer(prepare_data_per_node=False)
This is called before requesting the dataloaders:
    model.prepare_data()
    if ddp/tpu: init()
    model.setup(stage)
    model.train_dataloader()
    model.val_dataloader()
    model.test_dataloader()
size(dim=None)

Return the dimension of each input, either as a tuple or a list of tuples. You can index this just as you would with a torch tensor.

property dims

A tuple describing the shape of your data. Extra functionality exposed in size().
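A small sketch of how dims and size() interact; the (1, 28, 28) image shape is an illustrative assumption:

    class MyDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            # (channels, height, width) of each input
            self.dims = (1, 28, 28)

    dm = MyDataModule()
    dm.size()   # (1, 28, 28)
    dm.size(1)  # 28, indexed just like a torch tensor's size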
property has_prepared_data

Return a bool letting you know if datamodule.prepare_data() has been called or not.

Returns
    True if datamodule.prepare_data() has been called. False by default.

Return type
    bool

property has_setup_fit

Return a bool letting you know if datamodule.setup('fit') has been called or not.

Returns
    True if datamodule.setup('fit') has been called. False by default.

Return type
    bool

property has_setup_test

Return a bool letting you know if datamodule.setup('test') has been called or not.

Returns
    True if datamodule.setup('test') has been called. False by default.

Return type
    bool
property test_transforms

Optional transforms (or collection of transforms) you can apply to the test dataset.

property train_transforms

Optional transforms (or collection of transforms) you can apply to the train dataset.

property val_transforms

Optional transforms (or collection of transforms) you can apply to the validation dataset.
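A hedged sketch of assigning a transform pipeline; the torchvision transforms are an illustrative assumption, since any callable (or collection of callables) can be stored:

    from torchvision import transforms

    dm = MyDataModule()
    dm.train_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])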
pytorch_lightning.core.datamodule.track_data_hook_calls(fn)

A decorator that checks if prepare_data/setup have been called.

When dm.prepare_data() is called, dm.has_prepared_data gets set to True
When dm.setup('fit') is called, dm.has_setup_fit gets set to True
When dm.setup('test') is called, dm.has_setup_test gets set to True
When dm.setup() is called without a stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True

Parameters
    fn (function) – Function that will be tracked to see if it has been called.

Returns
    Decorated function that tracks its call status and saves it to private attrs on its object instance.

Return type
    function
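A brief sketch of the tracked flags in action, using the MyDataModule defined above (whose setup accepts an optional stage):

    dm = MyDataModule()
    assert not dm.has_prepared_data

    dm.prepare_data()
    assert dm.has_prepared_data   # flipped by the tracking decorator

    dm.setup('fit')
    assert dm.has_setup_fit and not dm.has_setup_test

    dm.setup()                    # no stage arg: both flags get set
    assert dm.has_setup_fit and dm.has_setup_test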