Skip to content

Examples

Template model definition

In 99% of cases you can simply copy this template to start a new LightningModule, then change the core of what your model actually does.

1
2
# download a copy of the LightningModule template into the current directory
wget https://raw.githubusercontent.com/williamFalcon/pytorch-lightning/master/examples/new_project_templates/lightning_module_template.py

Trainer Example

__main__ function

Normally, we let the __main__ block start the training. Inside it we parse the training arguments, including whatever hyperparameters we want, and your LightningModule gets a chance to add its own hyperparameters.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
import os
import sys

from test_tube import HyperOptArgumentParser

if __name__ == '__main__':

    # use default args given by lightning; root_dir is the parent of the
    # directory that contains the script being run
    root_dir = os.path.split(os.path.dirname(sys.modules['__main__'].__file__))[0]
    parent_parser = HyperOptArgumentParser(strategy='random_search', add_help=False)
    add_default_args(parent_parser, root_dir)

    # allow model to overwrite or extend args
    parser = ExampleModel.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()

    # train model locally: main takes (hparams, cluster, results_dict), so pass
    # no SLURM cluster and no results dict explicitly
    main(hyperparams, cluster=None, results_dict=None)

Main Function

The main function is your entry into the program. This is where you initialize your model and checkpoint directory, and launch the training. The main function takes 3 arguments:
- hparams: a configuration of hyperparameters.
- cluster (the SLURM cluster manager): a SLURM cluster manager object (can be None).
- results_dict: a dict for returning any values you want (useful in meta-learning; otherwise you can name it _ and ignore it).

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def main(hparams, cluster=None, results_dict=None):
    """
    Main training routine specific for this project.

    Creates a test_tube Experiment, builds the LightningModule, configures the
    early-stopping and checkpoint callbacks, then fits the model with a Trainer.

    :param hparams: parsed hyperparameters. The fields read here are
        early_stop_metric, early_stop_patience, early_stop_mode,
        model_save_path, model_save_monitor_value and model_save_monitor_mode.
    :param cluster: SLURM cluster manager passed through to the Trainer, or
        None (the default) when training locally.
    :param results_dict: dict a caller may use to collect values (e.g. in
        meta-learning). It is not read or written by this routine.
        Defaults to None.
    :return: None
    """
    # init experiment; logs are saved in the directory containing this file
    log_dir = os.path.dirname(os.path.realpath(__file__))
    exp = Experiment(
        name='test_tube_exp',
        debug=True,
        save_dir=log_dir,
        version=0,
        autosave=False,
        description='test demo'
    )

    # set the hparams for the experiment
    exp.argparse(hparams)
    exp.save()

    # build model
    model = MyLightningModule(hparams)

    # callbacks
    early_stop = EarlyStopping(
        monitor=hparams.early_stop_metric,
        patience=hparams.early_stop_patience,
        verbose=True,
        mode=hparams.early_stop_mode
    )

    # checkpoints go under <model_save_path>/<experiment name>/<experiment version>;
    # version may be an int, so convert it before joining
    model_save_path = os.path.join(hparams.model_save_path, exp.name, str(exp.version))
    checkpoint = ModelCheckpoint(
        filepath=model_save_path,
        save_function=None,
        save_best_only=True,
        verbose=True,
        monitor=hparams.model_save_monitor_value,
        mode=hparams.model_save_monitor_mode
    )

    # configure trainer (cluster is None for local runs)
    trainer = Trainer(
        experiment=exp,
        cluster=cluster,
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stop,
    )

    # train model
    trainer.fit(model)

The main function launches training with whatever hyperparameters it receives. If you use the HyperOptArgumentParser in hyperparameter-optimization mode, each call to main receives one sampled set of hyperparameters. If you use it as a simple argument parser, main receives the parser's default arguments.

So, calling main with just the parsed hyperparameters (and no SLURM cluster) trains the model with the default argparse arguments.

1
# local run: no SLURM cluster object and no results dict
main(hyperparams, cluster=None, results_dict=None)

1
2
3
4
5
6
# run a hyperparameter search over 20 trials on the CPU with a single worker.
# The search strategy (e.g. 'random_search') is the one chosen when the
# HyperOptArgumentParser was constructed.
# NOTE: main_local is not defined in this example; presumably it is your own
# wrapper around main — define it before running this.
hyperparams.optimize_parallel_cpu(
    main_local,
    nb_trials=20,
    nb_workers=1
)

Hyperparameter search on a single or multiple GPUs

1
2
3
4
5
6
7
# run a hyperparameter search over 20 trials on GPUs 0-3.
# The search strategy (e.g. 'random_search') is the one chosen when the
# HyperOptArgumentParser was constructed.
# NOTE: main_local is not defined in this example; presumably it is your own
# wrapper around main — define it before running this.
hyperparams.optimize_parallel_gpu(
    main_local,
    nb_trials=20,
    nb_workers=1,
    gpus=[0, 1, 2, 3]
)

Hyperparameter search on a SLURM HPC cluster

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def optimize_on_cluster(hyperparams, *, email='add_email_here', job_time='48:00:00',
                        gpu_type='1080ti', memory_mb_per_node=48000,
                        conda_env='pytorch_lightning'):
    """
    Submit a hyperparameter search to a SLURM cluster through test_tube.

    Builds a SlurmCluster around ``hyperparams``, configures it, and submits
    ``hyperparams.nb_hopt_trials`` trials that run ``main``.

    :param hyperparams: parsed HyperOptArgumentParser namespace. The fields read
        here are tt_save_path, tt_name, per_experiment_nb_gpus and nb_hopt_trials.
    :param email: address for job-status notifications (on done and on fail).
        The default is a placeholder; replace it with a real address.
    :param job_time: value for ``cluster.job_time`` ('HH:MM:SS').
    :param gpu_type: value for ``cluster.gpu_type``.
    :param memory_mb_per_node: value for ``cluster.memory_mb_per_node``.
    :param conda_env: environment activated via ``source activate`` before the
        code runs.
    :return: None
    """
    # enable cluster training
    cluster = SlurmCluster(
        hyperparam_optimizer=hyperparams,
        log_path=hyperparams.tt_save_path,
        test_tube_exp_name=hyperparams.tt_name
    )

    # email for cluster coms
    cluster.notify_job_status(email=email, on_done=True, on_fail=True)

    # configure cluster
    cluster.per_experiment_nb_gpus = hyperparams.per_experiment_nb_gpus
    cluster.job_time = job_time
    cluster.gpu_type = gpu_type
    cluster.memory_mb_per_node = memory_mb_per_node

    # any modules for code to run in env
    cluster.add_command('source activate {}'.format(conda_env))

    # short job name: first 3 characters of tt_name's prefix before the first '_'
    job_display_name = hyperparams.tt_name.split('_')[0][:3]

    # run hopt
    print('submitting jobs...')
    cluster.optimize_parallel_cluster_gpu(
        main,
        nb_trials=hyperparams.nb_hopt_trials,
        job_name=job_display_name
    )

# launch the SLURM hyperparameter search with the parsed arguments
optimize_on_cluster(hyperparams=hyperparams)