Shortcuts

Research projects tend to test different approaches to the same dataset. This is very easy to do in Lightning with inheritance.

For example, imagine we now want to train an AutoEncoder to use as a feature extractor for images. The only things that change in the LitAutoEncoder model are `__init__`, `forward`, and the training, validation, and test steps.

import torch
from pytorch_lightning import LightningModule, Trainer


class Encoder(torch.nn.Module):
    ...


class Decoder(torch.nn.Module):
    ...


class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.decoder(self.encoder(x))


class LitAutoEncoder(LightningModule):
    def __init__(self, auto_encoder):
        super().__init__()
        self.auto_encoder = auto_encoder
        self.metric = torch.nn.MSELoss()

    def forward(self, x):
        return self.auto_encoder.encoder(x)

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x_hat = self.auto_encoder(x)
        loss = self.metric(x, x_hat)
        return loss

    def validation_step(self, batch, batch_idx):
        self._shared_eval(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        self._shared_eval(batch, batch_idx, "test")

    def _shared_eval(self, batch, batch_idx, prefix):
        x, _ = batch
        x_hat = self.auto_encoder(x)
        loss = self.metric(x, x_hat)
        self.log(f"{prefix}_loss", loss)

We can then train this model using the Trainer:

auto_encoder = AutoEncoder()
lightning_module = LitAutoEncoder(auto_encoder)
trainer = Trainer()
trainer.fit(lightning_module, train_dataloader, val_dataloader)
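The call above assumes that `train_dataloader` and `val_dataloader` already exist. A minimal sketch of how they could be built, assuming random MNIST-shaped tensors stand in for a real dataset (the tensor shapes, sample counts, batch size, and dummy labels are placeholders chosen for illustration, not part of the original example):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 256 training and 64 validation "images" of
# shape (1, 28, 28) with dummy integer labels. A real project would load an
# actual dataset here.
train_images = torch.rand(256, 1, 28, 28)
train_labels = torch.zeros(256, dtype=torch.long)
val_images = torch.rand(64, 1, 28, 28)
val_labels = torch.zeros(64, dtype=torch.long)

# Each batch is an (images, labels) pair, matching the `x, _ = batch`
# unpacking used in the training, validation, and test steps.
train_dataloader = DataLoader(
    TensorDataset(train_images, train_labels), batch_size=32, shuffle=True
)
val_dataloader = DataLoader(
    TensorDataset(val_images, val_labels), batch_size=32
)
```

Because the autoencoder ignores the labels, any dataset that yields `(image, label)` pairs could be swapped in without touching the LightningModule.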

Remember that the forward method should define the practical use of a LightningModule. In this case, we want to use the LitAutoEncoder to extract image representations, so forward returns the encoder output rather than the reconstruction:

# torch.rand gives initialized values; torch.Tensor(32, 1, 28, 28) would return uninitialized memory
some_images = torch.rand(32, 1, 28, 28)
representations = lightning_module(some_images)
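The `Encoder` and `Decoder` bodies are elided above, so the shape of `representations` depends on whatever architecture is chosen. As a rough illustration only, here is a sketch that assumes a single linear layer on each side and a 64-dimensional code (both are hypothetical choices, not the architecture from the original example). It shows the difference between what `forward` returns (the code) and what the training step compares against the input (the reconstruction):

```python
import torch


class Encoder(torch.nn.Module):
    # Hypothetical architecture: flatten the image and project it to a
    # 64-dimensional code.
    def __init__(self, latent_dim=64):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(28 * 28, latent_dim),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


class Decoder(torch.nn.Module):
    # Hypothetical architecture: one linear layer back to pixel space.
    def __init__(self, latent_dim=64):
        super().__init__()
        self.net = torch.nn.Linear(latent_dim, 28 * 28)

    def forward(self, z):
        # Map the code back to pixels and restore the image shape.
        return self.net(z).view(-1, 1, 28, 28)


encoder, decoder = Encoder(), Decoder()
some_images = torch.rand(32, 1, 28, 28)

# The encoder output is what LitAutoEncoder.forward returns.
representations = encoder(some_images)
# The decoder output is what the training step compares with the input.
reconstructions = decoder(representations)
loss = torch.nn.MSELoss()(reconstructions, some_images)

print(representations.shape)  # torch.Size([32, 64])
print(reconstructions.shape)  # torch.Size([32, 1, 28, 28])
```

With these assumed layers, each image is reduced to a 64-element feature vector, which is the kind of output a downstream model would consume.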

© Copyright (c) 2018-2022, Lightning AI et al... Revision be581598.
