change a lot

budui 2020-10-14 18:55:51 +08:00
parent 0927fa3de5
commit 0019d4034c
11 changed files with 261 additions and 109 deletions

View File

@@ -1,34 +1,38 @@
name: horse2zebra-CyCleGAN
engine: CyCleGAN
name: selfie2anime-cycleGAN
engine: CycleGAN
result_dir: ./result
max_pairs: 266800
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: False
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: CyCle-Generator
_type: CycleGAN-Generator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
base_channels: 64
num_blocks: 9
padding_mode: reflect
norm_type: IN
discriminator:
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
loss:
gan:
@@ -41,17 +45,21 @@ loss:
weight: 10.0
id:
level: 1
weight: 0
weight: 10.0
mgc:
weight: 5
optimizers:
generator:
_type: Adam
lr: 2e-4
lr: 0.0001
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 2e-4
lr: 1e-4
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
data:
train:
@@ -60,15 +68,15 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 6
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/horse2zebra/trainA"
root_b: "/data/i2i/horse2zebra/trainB"
root_a: "/data/i2i/selfie2anime/trainA"
root_b: "/data/i2i/selfie2anime/trainB"
random_pair: True
pipeline:
- Load
@@ -82,16 +90,17 @@ data:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 4
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/horse2zebra/testA"
root_b: "/data/i2i/horse2zebra/testB"
root_a: "/data/i2i/selfie2anime/testA"
root_b: "/data/i2i/selfie2anime/testB"
random_pair: False
pipeline:
- Load
@@ -101,3 +110,15 @@ data:
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
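
For orientation: these configs are plain OmegaConf YAML, which is also how the engine kernels consume them. A minimal sketch of loading and overriding such a file (the path here is hypothetical):

from omegaconf import OmegaConf

# Load the YAML shown above (assumed path, for illustration only).
config = OmegaConf.load("config/selfie2anime-cyclegan.yaml")
print(config.engine)          # -> CycleGAN
print(config.loss.id.weight)  # -> 10.0 after this commit (was 0)

# Dotted overrides merge cleanly with the file contents.
config = OmegaConf.merge(
    config, OmegaConf.from_dotlist(["data.train.dataloader.batch_size=2"])
)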

View File

@@ -78,7 +78,7 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 4
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
@@ -102,7 +102,7 @@ data:
test:
which: video_dataset
dataloader:
batch_size: 8
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False

View File

@@ -1,26 +1,23 @@
from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.gan import GANLoss
from model.GAN.base import GANImageBuffer
from engine.util.container import GANImageBuffer, LossContainer
from engine.util.loss import pixel_loss, gan_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights
class TAFGEngineKernel(EngineKernel):
class CycleGANEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
self.gan_loss = gan_loss(config.loss.gan)
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss())
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
@@ -56,21 +53,19 @@ class TAFGEngineKernel(EngineKernel):
images["b2a"] = self.generators["b2a"](batch["b"])
images["a2b2a"] = self.generators["b2a"](images["a2b"])
images["b2a2b"] = self.generators["a2b"](images["b2a"])
if self.config.loss.id.weight > 0:
if self.id_loss.weight > 0:
images["a2a"] = self.generators["b2a"](batch["a"])
images["b2b"] = self.generators["a2b"](batch["b"])
return images
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
for phase in ["a2b", "b2a"]:
loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss(
generated[f"{phase}2{phase[0]}"], batch[phase[0]])
loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss(
self.discriminators[phase[-1]](generated[phase]), True)
if self.config.loss.id.weight > 0:
loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
for ph in "ab":
loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(
self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"]), True)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
@@ -97,5 +92,5 @@ class TAFGEngineKernel(EngineKernel):
def run(task, config, _):
kernel = TAFGEngineKernel(config)
kernel = CycleGANEngineKernel(config)
run_kernel(task, config, kernel)
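
engine/util/container.py is not part of this diff, so the following is only a sketch of LossContainer inferred from how the kernel uses it (a .weight attribute plus a call that applies the weighted criterion); the real implementation may differ:

import torch.nn as nn

class LossContainer(nn.Module):
    """Hypothetical reconstruction of engine.util.container.LossContainer."""

    def __init__(self, weight, loss_fn):
        super().__init__()
        self.weight = weight
        self.loss_fn = loss_fn

    def forward(self, prediction, target):
        # The kernel still guards with `if self.id_loss.weight > 0`,
        # so a zero weight can simply short-circuit.
        if self.weight == 0:
            return prediction.new_zeros(())
        return self.weight * self.loss_fn(prediction, target)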

View File

@@ -1,38 +1,31 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from engine.util.container import LossContainer
from engine.util.loss import bce_loss, mse_loss, pixel_loss, gan_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from loss.gan import GANLoss
from model.image_translation.UGATIT import RhoClipper
from util.image import attention_colored_map
def pixel_loss(level):
return nn.L1Loss() if level == 1 else nn.MSELoss()
class RhoClipper(object):
def __init__(self, clip_min, clip_max):
self.clip_min = clip_min
self.clip_max = clip_max
assert clip_min < clip_max
def mse_loss(x, target_flag):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def bce_loss(x, target_flag):
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def __call__(self, module):
if hasattr(module, 'rho'):
w = module.rho.data
w = w.clamp(self.clip_min, self.clip_max)
module.rho.data = w
class UGATITEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.gan_loss = gan_loss(config.loss.gan)
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))

25 engine/util/loss.py Normal file
View File

@@ -0,0 +1,25 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from loss.gan import GANLoss
def gan_loss(config):
gan_loss_cfg = OmegaConf.to_container(config)
gan_loss_cfg.pop("weight")
return GANLoss(**gan_loss_cfg).to(idist.device())
def pixel_loss(level):
return nn.L1Loss() if level == 1 else nn.MSELoss()
def mse_loss(x, target_flag):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def bce_loss(x, target_flag):
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
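
A quick sanity check of these helpers (the 30x30 logits mimic a PatchDiscriminator output; shapes are illustrative):

import torch
from engine.util.loss import bce_loss, mse_loss, pixel_loss

logits = torch.randn(4, 1, 30, 30)
print(mse_loss(logits, True))   # LSGAN-style loss against a target of ones
print(bce_loss(logits, False))  # vanilla GAN loss against a target of zeros
print(pixel_loss(1))            # -> L1Loss(); any other level gives MSELoss()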

View File

@@ -1,3 +1,4 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
@@ -5,17 +6,59 @@ import torch.nn as nn
def gaussian_radial_basis_function(x, mu, sigma):
# (kernel_size) -> (batch_size, kernel_size, c*h*w)
mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3))
mu = mu.to(x.device)
# (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1)
return torch.exp((x - mu).pow(2) / (2 * sigma ** 2))
class ImporveMyLoss(torch.nn.Module):
def __init__(self, device=idist.device()):
super().__init__()
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).to(device)
self.x_mu_list = mu.repeat(9).view(-1, 81)
self.y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
self.R = torch.eye(81).to(device)
def batch_ERSMI(self, I1, I2):
batch_size = I1.shape[0]
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
if I2.shape[1] == 1 and I1.shape[1] != 1:
I2 = I2.repeat(1, 3, 1, 1)
def kernel_F(y, mu_list, sigma):
tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1) # [81, 784]
tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1)
tmp_y = tmp_mu - tmp_y
mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2))
return mat_L
mat_K = kernel_F(I1, self.x_mu_list, 1)
mat_L = kernel_F(I2, self.y_mu_list, 1)
mat_k_l = mat_K * mat_L
H1 = (mat_K @ mat_K.transpose(1, 2)) * (mat_L @ mat_L.transpose(1, 2)) / (img_size ** 2)
h_hat = mat_k_l @ mat_k_l.transpose(1, 2) / img_size
small_h_hat = mat_K.sum(2).view(batch_size, -1, 1) * mat_L.sum(2).view(batch_size, -1, 1) / (img_size ** 2)
h_hat = 0.5 * H1 + 0.5 * h_hat
alpha = (h_hat + 0.05 * self.R).inverse() @ small_h_hat
ersmi = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
ersmi = -ersmi.squeeze().mean()
return ersmi
def forward(self, fakeI, realI):
return self.batch_ERSMI(fakeI, realI)
class MyLoss(torch.nn.Module):
def __init__(self):
super(MyLoss, self).__init__()
def forward(self, fakeI, realI):
fakeI = fakeI.cuda()
realI = realI.cuda()
def batch_ERSMI(I1, I2):
batch_size = I1.shape[0]
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
@@ -49,6 +92,7 @@ class MyLoss(torch.nn.Module):
alpha = alpha.matmul(h2)
ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul(
alpha) - 1).squeeze()
ersmi = -ersmi.mean()
return ersmi
@@ -61,16 +105,17 @@ class MGCLoss(nn.Module):
Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
"""
def __init__(self, beta=0.5, lambda_=0.05):
def __init__(self, beta=0.5, lambda_=0.05, device=idist.device()):
super().__init__()
self.beta = beta
self.lambda_ = lambda_
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0])
self.mu_x = mu.repeat(9)
self.mu_y = mu.unsqueeze(0).t().repeat(1, 9).view(-1)
mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)])
self.mu_x = mu_x.flatten().to(device)
self.mu_y = mu_y.flatten().to(device)
self.R = torch.eye(81).unsqueeze(0).to(device)
@staticmethod
def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_):
def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_, R):
assert img1.size() == img2.size()
num_pixel = img1.size(1) * img1.size(2) * img2.size(3)
@@ -79,33 +124,102 @@ class MGCLoss(nn.Module):
mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)
mat_k_mul_mat_l = mat_k * mat_l
h_hat = (1 - beta) * (mat_k_mul_mat_l.matmul(mat_k_mul_mat_l.transpose(1, 2))) / num_pixel
h_hat += beta * (mat_k.matmul(mat_k.transpose(1, 2)) * mat_l.matmul(mat_l.transpose(1, 2))) / (num_pixel ** 2)
h_hat = (1 - beta) * (mat_k_mul_mat_l @ mat_k_mul_mat_l.transpose(1, 2)) / num_pixel
h_hat += beta * ((mat_k @ mat_k.transpose(1, 2)) * (mat_l @ mat_l.transpose(1, 2))) / (num_pixel ** 2)
small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)
R = torch.eye(h_hat.size(1)).to(img1.device)
alpha = (h_hat + lambda_ * R).inverse().matmul(small_h_hat)
rSMI = (2 * alpha.transpose(1, 2).matmul(small_h_hat)) - alpha.transpose(1, 2).matmul(h_hat).matmul(alpha) - 1
return rSMI
alpha = (h_hat + lambda_ * R).inverse() @ small_h_hat
rSMI = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
return rSMI.squeeze()
def forward(self, fake, real):
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_)
return -rSMI.squeeze().mean()
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R)
return -rSMI.mean()
if __name__ == '__main__':
mg = MGCLoss().to("cuda")
mg = MGCLoss(device=torch.device("cpu"))
my = MyLoss().to("cuda")
imy = ImporveMyLoss()
from data.transform import transform_pipeline
def norm(x):
x -= x.min()
x /= x.max()
return (x - 0.5) * 2
pipeline = transform_pipeline(
['Load', 'ToTensor', {'Normalize': {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]}}])
img_a1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_1.jpg")
img_a2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_2.jpg")
img_a3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_3.jpg")
img_b1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_1.jpg")
img_b2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_2.jpg")
img_b3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_3.jpg")
x1 = norm(torch.randn(5, 3, 256, 256))
x2 = norm(x1 * 2 + 1)
x3 = norm(torch.randn(5, 3, 256, 256))
x4 = norm(torch.exp(x3))
print(mg(x1, x1), mg(x1, x2), mg(x1, x3), mg(x1, x4))
img_a1.requires_grad_(True)
img_a2.requires_grad_(True)
img_a3.requires_grad_(True)
# print("MyLoss")
# l1 = my(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
# l2 = my(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
# l3 = my(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
# l = (l1+l2+l3)/3
# l.backward()
# print(img_a1.grad[0][0][0:10])
# print(img_a2.grad[0][0][0:10])
# print(img_a3.grad[0][0][0:10])
#
# img_a1.grad = None
# img_a2.grad = None
# img_a3.grad = None
#
# print("---")
# l = my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# l.backward()
# print(img_a1.grad[0][0][0:10])
# print(img_a2.grad[0][0][0:10])
# print(img_a3.grad[0][0][0:10])
# img_a1.grad = None
# img_a2.grad = None
# img_a3.grad = None
print("MGCLoss")
l1 = mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
l2 = mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
l3 = mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
l = (l1 + l2 + l3) / 3
l.backward()
print(img_a1.grad[0][0][0:10])
print(img_a2.grad[0][0][0:10])
print(img_a3.grad[0][0][0:10])
img_a1.grad = None
img_a2.grad = None
img_a3.grad = None
print("---")
l = mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
l.backward()
print(img_a1.grad[0][0][0:10])
print(img_a2.grad[0][0][0:10])
print(img_a3.grad[0][0][0:10])
# print("\nMGCLoss")
# mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
# mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
# mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
#
# print("---")
# mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
#
# import pprofile
#
# profiler = pprofile.Profile()
# with profiler:
# iter_times = 1000
# for _ in range(iter_times):
# mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# for _ in range(iter_times):
# my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# for _ in range(iter_times):
# imy(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# profiler.print_stats()
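
Reading the tensor algebra off batch_rSMI above, the estimator it computes appears to be the penalized least-squares form below (notation mine, not the paper's: K and L are the 81 x n Gaussian kernel matrices over the n = num_pixel pixel values of the two images, \circ is the elementwise product, \mathbf{1} a ones vector):

\begin{aligned}
\hat{H} &= (1-\beta)\,\frac{(K \circ L)(K \circ L)^{\top}}{n}
         + \beta\,\frac{(K K^{\top}) \circ (L L^{\top})}{n^{2}}, \qquad
\hat{h} = \frac{(K\mathbf{1}) \circ (L\mathbf{1})}{n^{2}}, \\
\hat{\alpha} &= (\hat{H} + \lambda R)^{-1}\hat{h}, \qquad
\widehat{\mathrm{rSMI}} = 2\,\hat{\alpha}^{\top}\hat{h}
         - \hat{\alpha}^{\top}\hat{H}\hat{\alpha} - 1,
\end{aligned}

with the training loss being the negated batch mean, matching `return -rSMI.mean()` in forward.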

View File

@@ -1,3 +1,4 @@
from model.registry import MODEL, NORMALIZATION
import model.base.normalization
import model.image_translation
import model.image_translation.UGATIT
import model.image_translation.CycleGAN

View File

@@ -59,7 +59,12 @@ class Conv2dBlock(nn.Module):
self.activation_type = activation_type
self.pre_activation = pre_activation
conv = nn.ConvTranspose2d if use_transpose_conv else nn.Conv2d
if use_transpose_conv:
# Only "zeros" padding mode is supported for ConvTranspose2d
conv_kwargs["padding_mode"] = "zeros"
conv = nn.ConvTranspose2d
else:
conv = nn.Conv2d
if pre_activation:
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
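
The new use_transpose_conv branch mirrors a constraint PyTorch enforces at construction time, so forcing padding_mode to "zeros" avoids a ValueError:

import torch.nn as nn

nn.ConvTranspose2d(3, 3, kernel_size=3, padding_mode="zeros")  # fine
try:
    nn.ConvTranspose2d(3, 3, kernel_size=3, padding_mode="reflect")
except ValueError as err:
    print(err)  # Only "zeros" padding mode is supported for ConvTranspose2d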

View File

@@ -21,7 +21,7 @@ class Encoder(nn.Module):
multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode=padding_mode,
kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode="zeros",
activation_type=activation_type, norm_type=down_conv_norm_type
))
self.out_channels = multiple_now * base_channels
@@ -62,7 +62,7 @@ class Decoder(nn.Module):
for i in range(num_up_sampling):
if use_transpose_conv:
sequence.append(Conv2dBlock(
channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
channels, channels // 2, kernel_size=up_conv_kernel_size, stride=2,
padding=padding, output_padding=padding,
padding_mode=padding_mode,
activation_type=activation_type, norm_type=up_conv_norm_type,
@@ -90,7 +90,7 @@ class Decoder(nn.Module):
@MODEL.register_module("CycleGAN-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, activation_type="ReLU",
padding_mode='reflect', norm_type="IN", pre_activation=True, use_transpose_conv=True):
padding_mode='reflect', norm_type="IN", pre_activation=False, use_transpose_conv=True):
super().__init__()
self.encoder = Encoder(in_channels, base_channels, num_conv=2, num_res=num_blocks,
padding_mode=padding_mode, activation_type=activation_type,
@@ -106,7 +106,7 @@ class Generator(nn.Module):
@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
def __int__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
def __init__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
super().__init__()
self.need_intermediate_feature = need_intermediate_feature
@@ -118,7 +118,7 @@ class PatchDiscriminator(nn.Module):
)]
multiple_now = 1
for i in range(1, num_conv + 1):
for i in range(1, num_conv):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 3)
stride = 1 if i == num_conv - 1 else 2
@@ -143,3 +143,9 @@ class PatchDiscriminator(nn.Module):
return tuple(intermediate_feature)
else:
return self.sequence(x)
if __name__ == '__main__':
g = Generator(**dict(in_channels=3, out_channels=3))
print(g)
pd = PatchDiscriminator(**dict(in_channels=3, base_channels=64, num_conv=4))
print(pd)
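
With the loop bound fixed from range(1, num_conv + 1) to range(1, num_conv), num_conv=4 gives the initial block plus three more (the last at stride 1), i.e. the familiar C64-C128-C256-C512 PatchGAN stack. A hedged shape check in the spirit of the new __main__ block (the expected ~30x30 patch map assumes the standard 70x70-receptive-field layout; actual numbers depend on the final layers):

import torch

pd = PatchDiscriminator(in_channels=3, base_channels=64, num_conv=4)
out = pd(torch.randn(1, 3, 256, 256))
print(out.shape)  # roughly (1, 1, 30, 30) for a pix2pix-style 70x70 PatchGAN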

View File

@@ -6,19 +6,6 @@ from model.base.module import Conv2dBlock, LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
class RhoClipper(object):
def __init__(self, clip_min, clip_max):
self.clip_min = clip_min
self.clip_max = clip_max
assert clip_min < clip_max
def __call__(self, module):
if hasattr(module, 'rho'):
w = module.rho.data
w = w.clamp(self.clip_min, self.clip_max)
module.rho.data = w
class CAMClassifier(nn.Module):
def __init__(self, in_channels, activation_type="ReLU"):
super(CAMClassifier, self).__init__()

View File

@@ -1,8 +1,10 @@
import inspect
from omegaconf.dictconfig import DictConfig
from omegaconf import OmegaConf
from types import ModuleType
import warnings
from types import ModuleType
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
class _Registry:
def __init__(self, name):
@@ -136,8 +138,11 @@ class Registry(_Registry):
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
if self._module_dict[module_name] == module_class:
warnings.warn(f'{module_name} is already registered in {self.name}, but is the same class')
return
raise KeyError(f'{module_name}:{self._module_dict[module_name]} is already registered in {self.name}, '
f'so {module_class} cannot be registered')
self._module_dict[module_name] = module_class
def register_module(self, name=None, force=False, module=None):
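
The effect of the relaxed duplicate check, sketched with hypothetical registrations (names are illustrative):

from model.registry import MODEL

@MODEL.register_module("Demo-Model")
class DemoModel: ...

# Same name, same class: now only a warning instead of a KeyError.
MODEL.register_module("Demo-Model", module=DemoModel)

class OtherModel: ...
# Same name, different class: still raises KeyError (unless force=True).
# MODEL.register_module("Demo-Model", module=OtherModel)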