utils
Contains PyTorch-specific utility functions.
Module
Functions
autodetect_gpu
def autodetect_gpu() ‑> Dict[str, Any]:
Detects and returns GPU accelerator and device count.
Returns: A dictionary with the keys 'accelerator' and 'devices', which should be passed to the PyTorch Lightning Trainer.
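The following is a minimal sketch of how such a detection function might work and how its result could be consumed. It is not the library's implementation: the name autodetect_gpu_sketch, the CPU fallback, and the concrete values 'gpu' and 'cpu' are assumptions made for illustration. Only the two dictionary keys come from the description above.

```python
from typing import Any, Dict


def autodetect_gpu_sketch() -> Dict[str, Any]:
    """Hypothetical stand-in for autodetect_gpu(); the real logic may differ."""
    try:
        # Imported lazily so the sketch also runs where torch is not installed.
        import torch
    except ImportError:
        return {"accelerator": "cpu", "devices": 1}

    if torch.cuda.is_available():
        # Assumption: use every visible CUDA device.
        return {"accelerator": "gpu", "devices": torch.cuda.device_count()}

    # Assumption: fall back to a single CPU "device" when no GPU is found.
    return {"accelerator": "cpu", "devices": 1}


settings = autodetect_gpu_sketch()
print(settings)
```

Because the keys match Trainer keyword arguments, the dictionary can presumably be unpacked directly, for example Trainer(**settings).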
enhanced_torch_load
def enhanced_torch_load(f: Union[str, os.PathLike, BinaryIO, IO[bytes]], map_location: Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str], None] = None, pickle_module: Any = None, *, weights_only: bool = True, **pickle_load_args: Any) ‑> Any:
Calls torch.load() with sensible parameters; per the signature above, weights_only defaults to True. See the documentation of torch.load() for more information.
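A wrapper of this kind can be sketched as follows. This is an illustration under stated assumptions, not the library's code: the name load_checkpoint_sketch is hypothetical, and the only behaviour taken from the signature above is that weights_only defaults to True.

```python
from typing import Any


def load_checkpoint_sketch(
    f: Any,
    map_location: Any = None,
    *,
    weights_only: bool = True,
    **pickle_load_args: Any,
) -> Any:
    """Hypothetical wrapper that forwards to torch.load() with a safe default."""
    # Imported lazily so this sketch can be defined without torch installed.
    import torch

    # weights_only=True restricts unpickling to tensors and other basic types,
    # which avoids executing arbitrary code from an untrusted checkpoint.
    return torch.load(
        f,
        map_location=map_location,
        weights_only=weights_only,
        **pickle_load_args,
    )
```

A caller that trusts the file and needs to restore arbitrary pickled objects would pass weights_only=False explicitly.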
has_mps
def has_mps() ‑> bool:
Detect if MPS is available and torch can use it.
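One way such a check could be written is shown below. It is a sketch under assumptions: the name has_mps_sketch is hypothetical, and the real function may test different conditions. The calls torch.backends.mps.is_available() and torch.backends.mps.is_built() exist in recent PyTorch releases.

```python
def has_mps_sketch() -> bool:
    """Hypothetical stand-in for has_mps(); the real check may differ."""
    try:
        # Imported lazily so the sketch also runs where torch is not installed.
        import torch
    except ImportError:
        return False

    # Older torch versions have no torch.backends.mps attribute at all.
    mps = getattr(torch.backends, "mps", None)
    if mps is None:
        return False

    # is_built(): torch was compiled with MPS support.
    # is_available(): the current machine and OS can actually use it.
    return bool(mps.is_built() and mps.is_available())


print(has_mps_sketch())
```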
Classes
LoggerType
class LoggerType(value, names=None, *, module=None, qualname=None, type=None, start=1):
Different types of loggers for PyTorch Lightning.
With the exception of CSVLogger and TensorBoardLogger, all loggers require their corresponding Python libraries to be installed separately.
More information about PyTorch Lightning loggers can be found here: https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html
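The members of LoggerType are not listed here, so the following only sketches how an enumeration of this kind is typically defined and looked up. The class name LoggerTypeSketch, the member names CSV and TENSORBOARD, their string values, and the helper parse_logger_type are all hypothetical.

```python
from enum import Enum


class LoggerTypeSketch(Enum):
    """Hypothetical enum; the real LoggerType members may differ."""

    CSV = "csv"
    TENSORBOARD = "tensorboard"


def parse_logger_type(name: str) -> LoggerTypeSketch:
    """Look up a member by its case-insensitive name (hypothetical helper)."""
    return LoggerTypeSketch[name.upper()]


print(parse_logger_type("csv"))
print(LoggerTypeSketch("tensorboard"))
```

Lookup by name uses square brackets, while lookup by value uses a call; an unknown name raises KeyError and an unknown value raises ValueError.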