Examples
Template model definition
In 99% of cases you'll want to simply copy this template to start a new LightningModule and change the core of what your model is actually trying to do.
```bash
# get a copy of the module template
wget https://raw.githubusercontent.com/williamFalcon/pytorch-lightning/master/examples/new_project_templates/lightning_module_template.py
```
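To give a sense of what "the core of your model" means, here is a rough, illustrative sketch of the pieces the template asks you to fill in. The class name, layers, and dataloader are made up for the example, and the method names follow the current pytorch_lightning API, which may differ slightly from the template file linked above; treat the downloaded template as the authoritative version.

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class CoolModel(pl.LightningModule):
    """Illustrative module: swap the layers, loss, and data for your own task."""

    def __init__(self, hparams):
        super().__init__()
        # hparams is the argparse namespace built in __main__
        self.learning_rate = hparams.learning_rate
        self.batch_size = hparams.batch_size
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.l1(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        # the core of what your model is trying to do lives here
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        # random tensors keep the sketch self-contained; plug in your real dataset
        ds = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
        return DataLoader(ds, batch_size=self.batch_size)
```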
Trainer Example
__main__ function
Normally, we want the __main__ function to start training. Inside __main__ we parse the training arguments with whatever hyperparameters we want. Your LightningModule also gets a chance to add its own hyperparameters.
```python
import os
import sys

from test_tube import HyperOptArgumentParser

if __name__ == '__main__':

    # use default args given by lightning
    root_dir = os.path.split(os.path.dirname(sys.modules['__main__'].__file__))[0]
    parent_parser = HyperOptArgumentParser(strategy='random_search', add_help=False)
    add_default_args(parent_parser, root_dir)

    # allow model to overwrite or extend args
    parser = ExampleModel.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()

    # train model
    main(hyperparams)
```
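For reference, add_model_specific_args is typically a static method on your LightningModule that extends the parent parser. The sketch below is only an assumption about what ExampleModel might define: the batch_size and learning_rate arguments are made up for illustration, and opt_list is test_tube's way of marking a hyperparameter as tunable during a search.

```python
from test_tube import HyperOptArgumentParser

# import path for LightningModule may differ between releases
from pytorch_lightning import LightningModule


class ExampleModel(LightningModule):
    ...

    @staticmethod
    def add_model_specific_args(parent_parser):
        # build on top of the default args added in __main__
        # (keep the strategy in sync with the parent parser)
        parser = HyperOptArgumentParser(strategy='random_search',
                                        parents=[parent_parser], add_help=False)

        # plain arguments behave like regular argparse
        parser.add_argument('--batch_size', default=32, type=int)

        # opt_list flags a hyperparameter as tunable for the search
        parser.opt_list('--learning_rate', default=0.001, type=float,
                        options=[0.0001, 0.0005, 0.001], tunable=True)
        return parser
```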
Main Function
The main function is your entry point into the program. This is where you init your model, set up the checkpoint directory, and launch training.
The main function should have 3 arguments:
- hparams: a configuration of hyperparameters.
- cluster: the SLURM cluster manager object (can be None)
- results_dict: a dict for you to return any values you want (useful in meta-learning, otherwise set to _); see the short sketch after the code below
```python
def main(hparams, cluster, results_dict):
    """
    Main training routine specific for this project
    :param hparams:
    :return:
    """
    # init experiment
    log_dir = os.path.dirname(os.path.realpath(__file__))
    exp = Experiment(
        name='test_tube_exp',
        debug=True,
        save_dir=log_dir,
        version=0,
        autosave=False,
        description='test demo'
    )

    # set the hparams for the experiment
    exp.argparse(hparams)
    exp.save()

    # build model
    model = MyLightningModule(hparams)

    # callbacks
    early_stop = EarlyStopping(
        monitor=hparams.early_stop_metric,
        patience=hparams.early_stop_patience,
        verbose=True,
        mode=hparams.early_stop_mode
    )

    model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
    checkpoint = ModelCheckpoint(
        filepath=model_save_path,
        save_function=None,
        save_best_only=True,
        verbose=True,
        monitor=hparams.model_save_monitor_value,
        mode=hparams.model_save_monitor_mode
    )

    # configure trainer
    trainer = Trainer(
        experiment=exp,
        cluster=cluster,
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stop,
    )

    # train model
    trainer.fit(model)
```
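Note that results_dict is not used in the routine above; it is simply a mutable dict the caller can pass in so main has a way to hand values back (e.g. in a meta-learning loop). A minimal sketch of how the end of main could populate it, with made-up keys:

```python
def main(hparams, cluster, results_dict):
    # ... everything shown above ...
    trainer.fit(model)

    # hand values back to the caller through the mutable dict
    # (the keys below are made up for illustration)
    if results_dict is not None:
        results_dict['exp_version'] = exp.version
        results_dict['model_save_path'] = model_save_path
```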
The main function is what actually starts training. If you use the HyperOptArgumentParser in hyperparameter optimization mode, this main function will get one set of hyperparameters per trial. If you use it as a simple argument parser, you get the default arguments in the argument parser.
So, calling main(hyperparams) runs the model with the default argparse arguments.
```python
main(hyperparams)
```
CPU hyperparameter search
```python
# run a grid search over 20 hyperparameter combinations.
hyperparams.optimize_parallel_cpu(
    main_local,
    nb_trials=20,
    nb_workers=1
)
```
Hyperparameter search on single or multiple GPUs
```python
# run a grid search over 20 hyperparameter combinations.
hyperparams.optimize_parallel_gpu(
    main_local,
    nb_trials=20,
    nb_workers=1,
    gpus=[0, 1, 2, 3]
)
```
Hyperparameter search on a SLURM HPC cluster
```python
def optimize_on_cluster(hyperparams):
    # enable cluster training
    cluster = SlurmCluster(
        hyperparam_optimizer=hyperparams,
        log_path=hyperparams.tt_save_path,
        test_tube_exp_name=hyperparams.tt_name
    )

    # email for cluster coms
    cluster.notify_job_status(email='add_email_here', on_done=True, on_fail=True)

    # configure cluster
    cluster.per_experiment_nb_gpus = hyperparams.per_experiment_nb_gpus
    cluster.job_time = '48:00:00'
    cluster.gpu_type = '1080ti'
    cluster.memory_mb_per_node = 48000

    # any modules for code to run in env
    cluster.add_command('source activate pytorch_lightning')

    # name of exp
    job_display_name = hyperparams.tt_name.split('_')[0]
    job_display_name = job_display_name[0:3]

    # run hopt
    print('submitting jobs...')
    cluster.optimize_parallel_cluster_gpu(
        main,
        nb_trials=hyperparams.nb_hopt_trials,
        job_name=job_display_name
    )


# run cluster hyperparameter search
optimize_on_cluster(hyperparams)
```