raycv/util/misc.py
2020-09-26 17:47:24 +08:00

96 lines
3.1 KiB
Python

import logging
from pathlib import Path
from typing import Optional
import torch.nn as nn
def add_spectral_norm(module):
    """Apply spectral normalization to a 2-D convolution layer, at most once.

    Args:
        module: any object; only ``nn.Conv2d`` and ``nn.ConvTranspose2d``
            instances are affected.

    Returns:
        The result of ``nn.utils.spectral_norm(module)`` when ``module`` is a
        ``Conv2d``/``ConvTranspose2d`` that has no ``weight_u`` attribute yet;
        otherwise ``module`` itself, untouched.
    """
    is_conv = isinstance(module, (nn.Conv2d, nn.ConvTranspose2d))
    # An existing ``weight_u`` attribute is treated as the marker that the
    # layer has already been spectrally normalized, so it is not wrapped again.
    if not is_conv or hasattr(module, 'weight_u'):
        return module
    return nn.utils.spectral_norm(module)
def setup_logger(
    name: Optional[str] = None,
    level: int = logging.INFO,
    logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
    filepath: Optional[str] = None,
    distributed_rank: Optional[int] = None,
) -> logging.Logger:
    """Configure and return a logger with a console (and optionally file) handler.

    Any handlers already attached to the logger are detached and closed before
    the new ones are installed, so calling this function repeatedly for the
    same ``name`` neither duplicates output nor leaks open log files.

    On the main process (``distributed_rank == 0``):

    * the logger itself is set to ``DEBUG`` so that handlers decide what to keep;
    * a ``StreamHandler`` filtered at ``level`` is attached;
    * if ``filepath`` is given and its parent directory exists, a
      ``FileHandler`` filtered at ``DEBUG`` is attached; otherwise the warning
      ``"not set file logger"`` is logged.

    On any other process (``distributed_rank > 0``) only a ``NullHandler`` is
    attached.

    Args:
        name (str, optional): name of the logger. If None, the root logger is
            configured. Named loggers get ``propagate = False``.
        level (int): threshold of the console handler, e.g. CRITICAL, ERROR,
            WARNING, INFO, DEBUG.
        logger_format (str): logging format. By default,
            `%(asctime)s %(name)s %(levelname)s: %(message)s`
        filepath (str, optional): log file path. The file handler is only
            created when the parent directory of this path already exists.
        distributed_rank (int, optional): rank of the current process. If
            None, it is obtained from ``ignite.distributed.get_rank()``
            (which requires ``ignite`` to be importable).

    Returns:
        logging.Logger

    For example, to improve logs readability when training with a trainer and evaluator:

    .. code-block:: python

        trainer = ...
        evaluator = ...

        trainer.logger = setup_logger("trainer")
        evaluator.logger = setup_logger("evaluator")

        trainer.run(data, max_epochs=10)

        # Logs will look like
        # 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=5.
        # 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:5:23
        # 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1.
        # 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02
        # ...
    """
    logger = logging.getLogger(name)

    # Don't propagate to ancestors: handlers are attached to this logger
    # directly, so propagation would duplicate every record.
    # (Open question kept from the original: should the default
    # configuration be less open?)
    if name is not None:
        logger.propagate = False

    # Remove previous handlers. Iterate over a copy because removeHandler
    # mutates ``logger.handlers``. Each handler is also closed: removeHandler
    # alone leaves a FileHandler's file descriptor open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(logger_format)

    if distributed_rank is None:
        # Imported lazily so that callers passing an explicit rank do not
        # need ignite at all.
        import ignite.distributed as idist

        distributed_rank = idist.get_rank()

    if distributed_rank > 0:
        # Worker processes stay silent.
        logger.addHandler(logging.NullHandler())
    else:
        # The logger lets everything through; each handler applies its own
        # threshold (console: ``level``, file: DEBUG).
        logger.setLevel(logging.DEBUG)

        ch = logging.StreamHandler()
        ch.setLevel(level)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

        if filepath is not None and Path(filepath).parent.exists():
            fh = logging.FileHandler(filepath)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            logger.addHandler(fh)
        else:
            logger.warning("not set file logger")

    return logger