116 lines
3.7 KiB
Python
116 lines
3.7 KiB
Python
import importlib
|
|
import logging
|
|
import pkgutil
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
def add_spectral_norm(module):
    """Wrap a conv module with spectral normalization, idempotently.

    ``nn.Conv2d`` / ``nn.ConvTranspose2d`` modules are wrapped with
    ``nn.utils.spectral_norm`` unless they already carry the ``weight_u``
    attribute that spectral norm registers (i.e. were wrapped before).
    Every other module is returned unchanged.
    """
    is_conv = isinstance(module, (nn.Conv2d, nn.ConvTranspose2d))
    already_normalized = hasattr(module, 'weight_u')
    if not is_conv or already_normalized:
        return module
    return nn.utils.spectral_norm(module)
|
|
|
|
|
|
def import_submodules(package, recursive=True):
    """ Import all submodules of a module, recursively, including subpackages

    :param package: package (name or actual module)
    :type package: str | module
    :param recursive: when True, also descend into subpackages
    :rtype: dict[str, types.ModuleType]
    """
    # Accept either a dotted name or an already-imported module object.
    if isinstance(package, str):
        package = importlib.import_module(package)

    modules = {}
    for _loader, short_name, is_pkg in pkgutil.walk_packages(package.__path__):
        dotted = '.'.join((package.__name__, short_name))
        # Keyed by the short (leaf) name, per the documented return type.
        modules[short_name] = importlib.import_module(dotted)
        if is_pkg and recursive:
            modules.update(import_submodules(dotted))
    return modules
|
|
|
|
|
|
def setup_logger(
    name: Optional[str] = None,
    level: int = logging.INFO,
    logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
    filepath: Optional[str] = None,
    distributed_rank: Optional[int] = None,
) -> logging.Logger:
    """Setups logger: name, level, format etc.

    Args:
        name (str, optional): new name for the logger. If None, the standard logger is used.
        level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
        logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
        filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
        distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
            If None, distributed_rank is initialized to the rank of process.

    Returns:
        logging.Logger

    For example, to improve logs readability when training with a trainer and evaluator:

    .. code-block:: python

        from ignite.utils import setup_logger

        trainer = ...
        evaluator = ...

        trainer.logger = setup_logger("trainer")
        evaluator.logger = setup_logger("evaluator")

        trainer.run(data, max_epochs=10)

        # Logs will look like
        # 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=5.
        # 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:5:23
        # 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1.
        # 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02
        # ...

    """
    logger = logging.getLogger(name)

    # Named loggers must not propagate to ancestors, otherwise each record
    # would be emitted twice (once by our handler, once by the root's).
    # should we provide a default configuration less open ?
    if name is not None:
        logger.propagate = False

    # Remove previous handlers so repeated setup_logger calls don't stack them
    if logger.hasHandlers():
        for h in list(logger.handlers):
            logger.removeHandler(h)

    formatter = logging.Formatter(logger_format)

    if distributed_rank is None:
        # Deferred import: only needed when the caller did not supply a rank.
        import ignite.distributed as idist

        distributed_rank = idist.get_rank()

    if distributed_rank > 0:
        # Non-zero ranks (workers) stay silent: attach a no-op handler only.
        logger.addHandler(logging.NullHandler())
    else:
        # Logger itself stays wide open; handlers apply the actual filtering.
        logger.setLevel(logging.DEBUG)

        ch = logging.StreamHandler()
        ch.setLevel(level)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

        # Fix: previously the warning branch also fired when filepath was
        # None (the default), spamming a bogus warning on every plain call.
        # Warn only when file logging was requested but cannot be set up.
        if filepath is not None:
            if Path(filepath).parent.exists():
                fh = logging.FileHandler(filepath)
                fh.setLevel(logging.DEBUG)
                fh.setFormatter(formatter)
                logger.addHandler(fh)
            else:
                logger.warning("not set file logger")

    return logger
|