
utils

Contains PyTorch-specific utility methods.

Module

Functions

autodetect_gpu

def autodetect_gpu() -> Dict[str, Any]:

Detects and returns GPU accelerator and device count.

Returns: A dictionary with the keys 'accelerator' and 'devices', which should be passed to the PyTorch Lightning Trainer.
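A minimal sketch of the kind of selection logic such a function might perform. It is a hypothetical stand-in (`autodetect_gpu_sketch`), not this module's implementation. The hardware probes are passed in as parameters so the logic can run without a GPU. In real code they would come from `torch.cuda.device_count()` and `torch.backends.mps.is_available()`. The CPU fallback values are an assumption.

```python
from typing import Any, Dict


def autodetect_gpu_sketch(cuda_device_count: int, mps_available: bool) -> Dict[str, Any]:
    """Hypothetical stand-in for autodetect_gpu with injected hardware probes."""
    if cuda_device_count > 0:
        # Use every visible CUDA device.
        return {"accelerator": "gpu", "devices": cuda_device_count}
    if mps_available:
        # Apple silicon exposes a single MPS device.
        return {"accelerator": "mps", "devices": 1}
    # Assumed fallback when no accelerator is detected.
    return {"accelerator": "cpu", "devices": 1}
```

Because the result is a plain dictionary with the keys 'accelerator' and 'devices', it can be unpacked into the trainer's keyword arguments, e.g. `Trainer(**autodetect_gpu())`.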

enhanced_torch_load

def enhanced_torch_load(
    f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
    map_location: Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str], None] = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = True,
    **pickle_load_args: Any,
) -> Any:

Call torch.load() with sensible defaults; in particular, weights_only defaults to True.

See the docs of torch.load() for more information.
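A minimal sketch of a wrapper with this signature. It is a hypothetical stand-in (`enhanced_torch_load_sketch`) that only forwards its arguments to `torch.load()` while defaulting `weights_only` to True. It assumes PyTorch 1.13 or newer, where `torch.load()` accepts `weights_only`. Any other "sensible defaults" the real function applies are not reproduced here.

```python
import io
import os
from typing import IO, Any, BinaryIO, Callable, Dict, Union

import torch

MapLocation = Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str], None]


def enhanced_torch_load_sketch(
    f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
    map_location: MapLocation = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = True,
    **pickle_load_args: Any,
) -> Any:
    """Hypothetical stand-in: forward to torch.load with weights_only=True by default."""
    # weights_only=True restricts unpickling to tensors and other allow-listed
    # types, which is meant to stop a checkpoint from running arbitrary code.
    return torch.load(
        f,
        map_location=map_location,
        pickle_module=pickle_module,
        weights_only=weights_only,
        **pickle_load_args,
    )


# Usage: round-trip a small state dict through an in-memory buffer.
buffer = io.BytesIO()
torch.save({"weight": torch.ones(2, 2)}, buffer)
buffer.seek(0)
state = enhanced_torch_load_sketch(buffer, map_location="cpu")
```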

has_mps

def has_mps() -> bool:

Detect whether MPS (Apple's Metal Performance Shaders backend) is available and usable by torch.
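A minimal sketch of such a check, written as a hypothetical stand-in (`has_mps_sketch`). It relies on `torch.backends.mps.is_built()` (PyTorch was compiled with MPS support) and `torch.backends.mps.is_available()` (the current machine can use it). Returning False when torch is not installed, or when the `mps` backend module is missing, is an assumption of this sketch.

```python
def has_mps_sketch() -> bool:
    """Hypothetical stand-in for has_mps."""
    try:
        import torch
    except ImportError:
        # Assumption: no torch installed means no usable MPS device.
        return False
    # Older torch builds have no torch.backends.mps module at all.
    mps = getattr(torch.backends, "mps", None)
    return bool(mps is not None and mps.is_built() and mps.is_available())
```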

Classes

LoggerType

class LoggerType(value, names=None, *, module=None, qualname=None, type=None, start=1):

Different types of loggers for PyTorch Lightning.

With the exception of CSVLogger and TensorBoardLogger, all loggers require their corresponding Python libraries to be installed separately.

More information about PyTorch Lightning loggers can be found here: https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html

Ancestors

Variables

  • static CSVLogger
  • static MLFlow
  • static Neptune
  • static TensorBoard
  • static WeightsAndBiases
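A minimal sketch of how an enum like LoggerType can be paired with the corresponding PyTorch Lightning logger classes. `LoggerTypeSketch`, its `auto()` values, the `LIGHTNING_LOGGER_CLASS` mapping and `needs_extra_install` are all hypothetical. Only the member names come from the variables listed above. The "extra install" rule restates the note that only the CSV and TensorBoard loggers work without separately installed libraries.

```python
from enum import Enum, auto


class LoggerTypeSketch(Enum):
    # Member names mirror the documented variables; the values are hypothetical.
    CSVLogger = auto()
    MLFlow = auto()
    Neptune = auto()
    TensorBoard = auto()
    WeightsAndBiases = auto()


# Hypothetical mapping from enum member to the PyTorch Lightning logger class name.
LIGHTNING_LOGGER_CLASS = {
    LoggerTypeSketch.CSVLogger: "CSVLogger",
    LoggerTypeSketch.MLFlow: "MLFlowLogger",
    LoggerTypeSketch.Neptune: "NeptuneLogger",
    LoggerTypeSketch.TensorBoard: "TensorBoardLogger",
    LoggerTypeSketch.WeightsAndBiases: "WandbLogger",
}


def needs_extra_install(logger_type: LoggerTypeSketch) -> bool:
    """Return True if the logger needs a separately installed Python library."""
    return logger_type not in (LoggerTypeSketch.CSVLogger, LoggerTypeSketch.TensorBoard)
```

Keeping the class names as strings avoids importing optional logger backends until one is actually selected.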