import torch
import torch.nn as nn

from ignite.distributed import utils as idist
from ignite.distributed.comp_models import native as idist_native
from ignite.utils import setup_logger


def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module:
    """Helper method to adapt provided model for non-distributed and distributed configurations (supporting
    all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

    Internally, we perform the following:

    - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
    - wrap the model with `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
    - wrap the model with `torch DataParallel`_ if no distributed context is found and more than one CUDA device is available.

    Examples:

    .. code-block:: python

        import ignite.distributed as idist

        model = idist.auto_model(model)

    In addition, with NVIDIA/Apex it can be used in the following way:

    .. code-block:: python

        import ignite.distributed as idist
        from apex import amp

        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
        model = idist.auto_model(model)
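
    Any additional keyword arguments are passed through to the wrapper class. A minimal sketch, assuming a native
    distributed configuration (``find_unused_parameters`` is a standard ``DistributedDataParallel`` argument, not
    something defined here):

    .. code-block:: python

        import ignite.distributed as idist

        # forwarded to torch.nn.parallel.DistributedDataParallel when a distributed context is active
        model = idist.auto_model(model, find_unused_parameters=True)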

    Args:
        model (torch.nn.Module): model to adapt.
        **additional_kwargs: keyword arguments passed through to `torch DistributedDataParallel`_ or
            `torch DataParallel`_.

    Returns:
        torch.nn.Module

    .. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel
    .. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
    """
    logger = setup_logger(__name__ + ".auto_model")

    # Move the model's parameters to the current device if they are not there already
    device = idist.device()
    if not all([p.device == device for p in model.parameters()]):
        model.to(device)

    # distributed data parallel model
    if idist.get_world_size() > 1:
        if idist.backend() == idist_native.NCCL:
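            # NCCL: each process drives a single GPU, so DDP is pinned to the local rank's device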
            lrank = idist.get_local_rank()
            logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank, ], **additional_kwargs)
        elif idist.backend() == idist_native.GLOO:
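            # GLOO: no device ids are passed; DDP uses the model's current device (typically CPU)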
            logger.info("Apply torch DistributedDataParallel on model")
            model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs)

    # not distributed but multiple GPUs reachable so data parallel model
    elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
        logger.info("Apply torch DataParallel on model")
        model = torch.nn.parallel.DataParallel(model, **additional_kwargs)

    return model