loss

Implements loss functions for PyTorch modules.

Module

Functions

soft_dice_loss

def soft_dice_loss(pred: numpy.ndarray, targets: numpy.ndarray, square_nom: bool = False, square_denom: bool = False, weight: Union[Sequence, torch.Tensor, None] = None, smooth: float = 1.0) -> torch.Tensor:

Functional implementation of the SoftDiceLoss.

Arguments

  • pred: A numpy array of predictions.
  • targets: A numpy array of targets.
  • square_nom: Whether to square the numerator.
  • square_denom: Whether to square the denominator.
  • weight: Additional weighting of individual classes.
  • smooth: Smoothing term added to both the numerator and the denominator.

Returns A torch tensor with the computed dice loss.

Classes

SoftDiceLoss

class SoftDiceLoss(square_nom: bool = False, square_denom: bool = False, weight: Union[Sequence, torch.Tensor, None] = None, smooth: float = 1.0):

Soft Dice Loss.

The soft Dice loss is computed as a fraction of numerator over denominator, where the numerator is 2 * the area of overlap between targets and predictions plus a smooth factor, and the denominator is the total number of pixels in both images plus the smooth factor. If weights are provided, the fraction is multiplied by the provided weight for each class. If square_nom or square_denom is set, the respective numerator or denominator is raised to the power of 2.
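The computation described above can be sketched in plain PyTorch. This is a minimal illustration, not the library's implementation: the name soft_dice_sketch is hypothetical, both inputs are assumed to be already one-hot encoded with shape (N, C, *spatial), and returning one minus the mean fraction is an assumed convention that the description does not confirm.

```python
from typing import Optional, Sequence, Union

import torch


def soft_dice_sketch(
    pred: torch.Tensor,
    target_onehot: torch.Tensor,
    square_nom: bool = False,
    square_denom: bool = False,
    weight: Optional[Union[Sequence[float], torch.Tensor]] = None,
    smooth: float = 1.0,
) -> torch.Tensor:
    """Hypothetical re-implementation of the description (not the library code)."""
    # Assumption: inputs have shape (N, C, *spatial); reduce over the spatial axes.
    dims = tuple(range(2, pred.dim()))
    overlap = (pred * target_onehot).sum(dim=dims)

    # Numerator: 2 * overlap + smooth, optionally squared.
    nom = 2.0 * overlap + smooth
    if square_nom:
        nom = nom ** 2

    # Denominator: summed mass of both inputs + smooth, optionally squared.
    denom = pred.sum(dim=dims) + target_onehot.sum(dim=dims) + smooth
    if square_denom:
        denom = denom ** 2

    frac = nom / denom  # per-sample, per-class fraction of shape (N, C)
    if weight is not None:
        # One weight per class, broadcast over the batch axis.
        frac = frac * torch.as_tensor(weight, dtype=frac.dtype, device=frac.device)

    # Assumption: report 1 - mean fraction so that perfect overlap gives 0.
    return 1.0 - frac.mean()


target = torch.zeros(1, 2, 2, 2)
target[:, 0] = 1.0  # every pixel belongs to class 0
perfect = soft_dice_sketch(target, target)        # predictions equal targets
disjoint = soft_dice_sketch(1.0 - target, target)  # predictions miss every pixel
```

With smooth left at its default of 1, the fraction stays defined even when a class is absent from both inputs, which is the usual reason for adding the term to both sides.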

Arguments

  • square_nom: Whether to square the numerator. Optional.
  • square_denom: Whether to square the denominator. Optional.
  • weight: Additional weighting of individual classes. Optional.
  • smooth: Smoothing term added to both the numerator and the denominator. Optional. Defaults to 1.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Variables

  • static dump_patches : bool
  • static training : bool

Methods


forward

def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:

Computes Soft Dice Loss.

Arguments

  • predictions: The predictions obtained by the network.
  • targets: The targets (ground truth) for the predictions.

Returns torch.Tensor: The computed loss value

Raises

  • ValueError: If the predictions tensor has fewer than 3 dimensions.
  • ValueError: If the targets tensor has fewer than 2 dimensions.
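The two dimension checks suggest that predictions carry a class axis, for example (N, C, *spatial), while targets do not, for example (N, *spatial). The sketch below illustrates only that reading. The class name DiceShapeCheck is hypothetical, and the assumption that targets hold integer class indices to be one-hot encoded is not confirmed by the documentation.

```python
import torch
import torch.nn.functional as F
from torch import nn


class DiceShapeCheck(nn.Module):
    """Hypothetical module that illustrates the documented dimension checks."""

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        if predictions.dim() < 3:
            raise ValueError(
                f"predictions need at least 3 dimensions, got {predictions.dim()}"
            )
        if targets.dim() < 2:
            raise ValueError(
                f"targets need at least 2 dimensions, got {targets.dim()}"
            )

        # Assumption: targets are class indices of shape (N, *spatial).
        n_classes = predictions.shape[1]
        onehot = F.one_hot(targets.long(), n_classes)  # shape (N, *spatial, C)

        # Move the class axis from the last position to position 1.
        perm = (0, targets.dim()) + tuple(range(1, targets.dim()))
        return onehot.permute(*perm).to(predictions.dtype)  # shape (N, C, *spatial)


checker = DiceShapeCheck()
aligned = checker(torch.rand(2, 3, 4, 4), torch.randint(0, 3, (2, 4, 4)))
```

Here aligned has the same shape as the predictions, so it could be multiplied with them element-wise to obtain the per-class overlap.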