Shortcuts

# Source code for pytorch_lightning.metrics.functional.mean_squared_error

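# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.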
from typing import Tuple

import torch

from pytorch_lightning.metrics.utils import _check_same_shape

def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Collect the running statistics needed to compute MSE for one batch.

    Args:
        preds: predicted values; their shape is validated against ``target``
            by ``_check_same_shape`` before any arithmetic is done.
        target: ground-truth values.

    Returns:
        A ``(sum_squared_error, n_obs)`` pair: the sum over all elements of
        ``(preds - target) ** 2`` and the total element count of ``target``.
    """
    _check_same_shape(preds, target)
    # Square the element-wise residuals and reduce over every element.
    total = (preds - target).pow(2).sum()
    count = target.numel()
    return total, count

def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor:
return sum_squared_error / n_obs

def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Computes the mean squared error between ``preds`` and ``target``, i.e. the
    average of the element-wise squared differences taken over all elements.

    Args:
        preds: estimated values; must have the same shape as ``target``
            (validated by ``_mean_squared_error_update``)
        target: ground truth values

    Return:
        Scalar (0-dim) tensor with the MSE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mean_squared_error(x, y)
        tensor(0.2500)

    """
    # Split into update/compute so the same statistics can be accumulated
    # across batches by callers that reuse the two helpers.
    sum_squared_error, n_obs = _mean_squared_error_update(preds, target)
    return _mean_squared_error_compute(sum_squared_error, n_obs)


© Copyright (c) 2018-2021, William Falcon et al. Revision c462b274.

Built with Sphinx using a theme provided by Read the Docs.
Versions
latest
stable
1.1.6
1.1.5
1.1.4
1.1.3
1.1.2
1.1.1
1.1.0
1.0.8
1.0.7
1.0.6
1.0.5
1.0.4
1.0.3
1.0.2
1.0.1
1.0.0
0.10.0
0.9.0
0.8.5
0.8.4
0.8.3
0.8.2
0.8.1
0.8.0
0.7.6
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.3.2
0.5.3
0.4.9
release-1.0.x