import torch
import torch.nn as nn

from ignite.distributed import utils as idist
from ignite.distributed.comp_models import native as idist_native
from ignite.utils import setup_logger


def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module:
    """Helper method to adapt the provided model for non-distributed and distributed configurations
    (supporting all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

    Internally, the following is performed:

    - send the model to the current :meth:`~ignite.distributed.utils.device()` if its parameters are not
      already on that device.
    - wrap the model with `torch DistributedDataParallel`_ for native torch distributed if world size is
      larger than 1.
    - wrap the model with `torch DataParallel`_ if no distributed context is found and more than one
      CUDA device is available.

    Examples:

    .. code-block:: python

        import ignite.distributed as idist

        model = idist.auto_model(model)

    In addition, with NVIDIA/Apex it can be used in the following way:

    .. code-block:: python

        import ignite.distributed as idist

        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
        model = idist.auto_model(model)

    Args:
        model (torch.nn.Module): model to adapt.
        **additional_kwargs: keyword arguments passed to the ``DistributedDataParallel`` or
            ``DataParallel`` constructor.

    Returns:
        torch.nn.Module

    .. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel
    .. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
    """
    logger = setup_logger(__name__ + ".auto_model")

    # Move the model's parameters to the current device if they are not already there
    device = idist.device()
    if not all(p.device == device for p in model.parameters()):
        model.to(device)

    # Distributed data parallel model
    if idist.get_world_size() > 1:
        if idist.backend() == idist_native.NCCL:
            lrank = idist.get_local_rank()
            logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank], **additional_kwargs)
        elif idist.backend() == idist_native.GLOO:
            logger.info("Apply torch DistributedDataParallel on model")
            model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs)
    # Not distributed, but multiple GPUs are reachable, so wrap with DataParallel
    elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
        logger.info("Apply torch DataParallel on model")
        model = torch.nn.parallel.DataParallel(model, **additional_kwargs)

    return model
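

# The block below is a minimal usage sketch, not part of the library itself:
# it only illustrates how ``auto_model`` is typically called on a plain
# ``nn.Module``. ``_DemoNet`` is a hypothetical toy model introduced for this
# example; in a non-distributed, single-device run, ``auto_model`` simply moves
# the model to ``idist.device()`` and returns it, while with more than one
# visible CUDA device it returns a ``DataParallel``-wrapped model.
if __name__ == "__main__":
    import torch.nn.functional as F

    class _DemoNet(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc = nn.Linear(10, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.relu(self.fc(x))

    # Adapt the toy model to the current configuration and run a forward pass
    demo_model = auto_model(_DemoNet())
    batch = torch.rand(4, 10, device=idist.device())
    print(demo_model(batch).shape)  # expected: torch.Size([4, 2])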