diff --git a/.idea/deployment.xml b/.idea/deployment.xml index 8ccfb5e..d56324a 100644 --- a/.idea/deployment.xml +++ b/.idea/deployment.xml @@ -1,6 +1,6 @@ - + diff --git a/configs/synthesizers/UGATIT.yml b/configs/synthesizers/UGATIT.yml index 4fe5ebf..bf1eb1a 100644 --- a/configs/synthesizers/UGATIT.yml +++ b/configs/synthesizers/UGATIT.yml @@ -14,11 +14,15 @@ handler: n_saved: 2 tensorboard: scalar: 100 # log scalar `scalar` times per epoch - image: 2 # log image `image` times per epoch + image: 4 # log image `image` times per epoch + test: + random: True + images: 10 model: generator: _type: UGATIT-Generator + _add_spectral_norm: True in_channels: 3 out_channels: 3 base_channels: 64 @@ -27,11 +31,13 @@ model: light: True local_discriminator: _type: UGATIT-Discriminator + _add_spectral_norm: True in_channels: 3 base_channels: 64 num_blocks: 5 global_discriminator: _type: UGATIT-Discriminator + _add_spectral_norm: True in_channels: 3 base_channels: 64 num_blocks: 7 @@ -50,6 +56,8 @@ loss: weight: 10.0 cam: weight: 1000 + mgc: + weight: 0 optimizers: generator: @@ -70,7 +78,7 @@ data: target_lr: 0 buffer_size: 50 dataloader: - batch_size: 24 + batch_size: 4 shuffle: True num_workers: 2 pin_memory: True diff --git a/engine/U-GAT-IT.py b/engine/U-GAT-IT.py index ed9eebc..860905d 100644 --- a/engine/U-GAT-IT.py +++ b/engine/U-GAT-IT.py @@ -1,16 +1,15 @@ -from omegaconf import OmegaConf - +import ignite.distributed as idist import torch import torch.nn as nn import torch.nn.functional as F -import ignite.distributed as idist +from omegaconf import OmegaConf -from loss.gan import GANLoss -from model.GAN.UGATIT import RhoClipper -from model.GAN.base import GANImageBuffer -from util.image import attention_colored_map from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel from engine.util.build import build_model +from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss +from loss.gan import GANLoss +from model.image_translation.UGATIT import RhoClipper +from util.image import attention_colored_map def mse_loss(x, target_flag): @@ -30,9 +29,8 @@ class UGATITEngineKernel(EngineKernel): self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss() + self.mgc_loss = MyLoss() self.rho_clipper = RhoClipper(0, 1) - self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in - self.discriminators.keys()} self.train_generator_first = False def build_models(self) -> (dict, dict): @@ -82,6 +80,9 @@ class UGATITEngineKernel(EngineKernel): loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase]) loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"]) + if self.config.loss.mgc.weight > 0: + loss[f"mgc_{phase}"] = self.config.loss.mgc.weight * self.mgc_loss( + batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"]) for dk in "lg": generated_image = generated["images"]["a2b" if phase == "b" else "b2a"] pred_fake, cam_pred = self.discriminators[dk + phase](generated_image) diff --git a/engine/base/i2i.py b/engine/base/i2i.py index 95c5897..74822e6 100644 --- a/engine/base/i2i.py +++ b/engine/base/i2i.py @@ -64,7 +64,7 @@ class EngineKernel(object): self.engine = engine def build_models(self) -> (dict, dict): - raise NotImplemented + raise NotImplementedError def to_save(self): to_save = {} @@ -73,19 +73,19 @@ class EngineKernel(object): return to_save def setup_after_g(self): - raise NotImplemented + raise NotImplementedError def setup_before_g(self): - raise NotImplemented + raise NotImplementedError def forward(self, batch, inference=False) -> dict: - raise NotImplemented + raise NotImplementedError def criterion_generators(self, batch, generated) -> dict: - raise NotImplemented + raise NotImplementedError def criterion_discriminators(self, batch, generated) -> dict: - raise NotImplemented + raise NotImplementedError def intermediate_images(self, batch, generated) -> dict: """ @@ -94,7 +94,7 @@ class EngineKernel(object): :param generated: dict of images :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} """ - raise NotImplemented + raise NotImplementedError def change_engine(self, config, engine: Engine): pass diff --git a/engine/util/build.py b/engine/util/build.py index 6e59e22..b423586 100644 --- a/engine/util/build.py +++ b/engine/util/build.py @@ -1,18 +1,21 @@ -import torch import ignite.distributed as idist - +import torch +import torch.optim as optim from omegaconf import OmegaConf from model import MODEL -import torch.optim as optim +from util.misc import add_spectral_norm def build_model(cfg): cfg = OmegaConf.to_container(cfg) bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False) + add_spectral_norm_flag = cfg.pop("_add_spectral_norm", False) model = MODEL.build_with(cfg) if bn_to_sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + if add_spectral_norm_flag: + model.apply(add_spectral_norm) return idist.auto_model(model) diff --git a/loss/I2I/minimal_geometry_distortion_constraint_loss.py b/loss/I2I/minimal_geometry_distortion_constraint_loss.py new file mode 100644 index 0000000..ed662c3 --- /dev/null +++ b/loss/I2I/minimal_geometry_distortion_constraint_loss.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + + +def gaussian_radial_basis_function(x, mu, sigma): + # (kernel_size) -> (batch_size, kernel_size, c*h*w) + mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3)) + mu = mu.to(x.device) + # (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w) + x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1) + return torch.exp((x - mu).pow(2) / (2 * sigma ** 2)) + + +class MyLoss(torch.nn.Module): + def __init__(self): + super(MyLoss, self).__init__() + + def forward(self, fakeI, realI): + def batch_ERSMI(I1, I2): + batch_size = I1.shape[0] + img_size = I1.shape[1] * I1.shape[2] * I1.shape[3] + if I2.shape[1] == 1 and I1.shape[1] != 1: + I2 = I2.repeat(1, 3, 1, 1) + + def kernel_F(y, mu_list, sigma): + tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1).cuda() # [81, 784] + tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1) + tmp_y = tmp_mu - tmp_y + mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2)) + return mat_L + + mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).cuda() + + x_mu_list = mu.repeat(9).view(-1, 81) + y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81) + + mat_K = kernel_F(I1, x_mu_list, 1) + mat_L = kernel_F(I2, y_mu_list, 1) + + H1 = ((mat_K.matmul(mat_K.transpose(1, 2))).mul(mat_L.matmul(mat_L.transpose(1, 2))) / ( + img_size ** 2)).cuda() + H2 = ((mat_K.mul(mat_L)).matmul((mat_K.mul(mat_L)).transpose(1, 2)) / img_size).cuda() + h2 = ((mat_K.sum(2).view(batch_size, -1, 1)).mul(mat_L.sum(2).view(batch_size, -1, 1)) / ( + img_size ** 2)).cuda() + H2 = 0.5 * H1 + 0.5 * H2 + tmp = H2 + 0.05 * torch.eye(len(H2[0])).cuda() + alpha = (tmp.inverse()) + + alpha = alpha.matmul(h2) + ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul( + alpha) - 1).squeeze() + ersmi = -ersmi.mean() + return ersmi + + batch_loss = batch_ERSMI(fakeI, realI) + return batch_loss + + +class MGCLoss(nn.Module): + """ + Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ + """ + + def __init__(self, beta=0.5, lambda_=0.05): + super().__init__() + self.beta = beta + self.lambda_ = lambda_ + mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]) + self.mu_x = mu.repeat(9) + self.mu_y = mu.unsqueeze(0).t().repeat(1, 9).view(-1) + + @staticmethod + def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_): + assert img1.size() == img2.size() + + num_pixel = img1.size(1) * img1.size(2) * img2.size(3) + + mat_k = gaussian_radial_basis_function(img1, mu_x, sigma=1) + mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1) + + mat_k_mul_mat_l = mat_k * mat_l + h_hat = (1 - beta) * (mat_k_mul_mat_l.matmul(mat_k_mul_mat_l.transpose(1, 2))) / num_pixel + h_hat += beta * (mat_k.matmul(mat_k.transpose(1, 2)) * mat_l.matmul(mat_l.transpose(1, 2))) / (num_pixel ** 2) + small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2) + + R = torch.eye(h_hat.size(1)).to(img1.device) + alpha = (h_hat + lambda_ * R).inverse().matmul(small_h_hat) + + rSMI = (2 * alpha.transpose(1, 2).matmul(small_h_hat)) - alpha.transpose(1, 2).matmul(h_hat).matmul(alpha) - 1 + return rSMI + + def forward(self, fake, real): + rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_) + return -rSMI.squeeze().mean() + + +if __name__ == '__main__': + mg = MGCLoss().to("cuda") + + + def norm(x): + x -= x.min() + x /= x.max() + return (x - 0.5) * 2 + + + x1 = norm(torch.randn(5, 3, 256, 256)) + x2 = norm(x1 * 2 + 1) + x3 = norm(torch.randn(5, 3, 256, 256)) + x4 = norm(torch.exp(x3)) + print(mg(x1, x1), mg(x1, x2), mg(x1, x3), mg(x1, x4)) diff --git a/model/GAN/CycleGAN.py b/model/GAN/CycleGAN.py deleted file mode 100644 index 61cc3be..0000000 --- a/model/GAN/CycleGAN.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch.nn as nn - -from model.normalization import select_norm_layer -from model.registry import MODEL -from .base import ResidualBlock - - -@MODEL.register_module("CyCle-Generator") -class Generator(nn.Module): - def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect', - norm_type="IN"): - super(Generator, self).__init__() - assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.' - norm_layer = select_norm_layer(norm_type) - use_bias = norm_type == "IN" - - self.start_conv = nn.Sequential( - nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3, - bias=use_bias), - norm_layer(num_features=base_channels), - nn.ReLU(inplace=True) - ) - - # down sampling - submodules = [] - num_down_sampling = 2 - for i in range(num_down_sampling): - multiple = 2 ** i - submodules += [ - nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2, - kernel_size=3, stride=2, padding=1, bias=use_bias), - norm_layer(num_features=base_channels * multiple * 2), - nn.ReLU(inplace=True) - ] - self.encoder = nn.Sequential(*submodules) - - res_block_channels = num_down_sampling ** 2 * base_channels - self.resnet_middle = nn.Sequential( - *[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in - range(num_blocks)]) - - # up sampling - submodules = [] - for i in range(num_down_sampling): - multiple = 2 ** (num_down_sampling - i) - submodules += [ - nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2, - padding=1, output_padding=1, bias=use_bias), - norm_layer(num_features=base_channels * multiple // 2), - nn.ReLU(inplace=True), - ] - self.decoder = nn.Sequential(*submodules) - - self.end_conv = nn.Sequential( - nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode), - nn.Tanh() - ) - - def forward(self, x): - x = self.encoder(self.start_conv(x)) - x = self.resnet_middle(x) - return self.end_conv(self.decoder(x)) diff --git a/model/GAN/MUNIT.py b/model/GAN/MUNIT.py deleted file mode 100644 index c113cf0..0000000 --- a/model/GAN/MUNIT.py +++ /dev/null @@ -1,150 +0,0 @@ -import torch -import torch.nn as nn - -from model import MODEL -from model.GAN.base import Conv2dBlock, ResBlock -from model.normalization import select_norm_layer - - -class StyleEncoder(nn.Module): - def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False, - max_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"): - super(StyleEncoder, self).__init__() - - sequence = [Conv2dBlock( - in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode, - use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type - )] - - multiple_now = 1 - for i in range(1, num_conv + 1): - multiple_prev = multiple_now - multiple_now = min(2 ** i, 2 ** max_multiple) - sequence.append(Conv2dBlock( - multiple_prev * base_channels, multiple_now * base_channels, - kernel_size=4, stride=2, padding=1, padding_mode=padding_mode, - use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type - )) - sequence.append(nn.AdaptiveAvgPool2d(1)) - # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code - sequence.append(nn.Conv2d(multiple_now * base_channels, out_dim, kernel_size=1, stride=1, padding=0)) - self.model = nn.Sequential(*sequence) - - def forward(self, x): - return self.model(x).view(x.size(0), -1) - - -class ContentEncoder(nn.Module): - def __init__(self, in_channels, num_down_sampling, num_res_blocks, base_channels=64, use_spectral_norm=False, - padding_mode='reflect', activation_type="ReLU", norm_type="IN"): - super().__init__() - sequence = [Conv2dBlock( - in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode, - use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type - )] - - for i in range(num_down_sampling): - sequence.append(Conv2dBlock( - base_channels * (2 ** i), base_channels * (2 ** (i + 1)), - kernel_size=4, stride=2, padding=1, padding_mode=padding_mode, - use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type - )) - - sequence += [ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type, - activation_type) for _ in range(num_res_blocks)] - self.sequence = nn.Sequential(*sequence) - - def forward(self, x): - return self.sequence(x) - - -class Decoder(nn.Module): - def __init__(self, in_channels, out_channels, num_up_sampling, num_res_blocks, - use_spectral_norm=False, res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU", - padding_mode='reflect'): - super(Decoder, self).__init__() - self.res_norm_type = res_norm_type - self.res_blocks = nn.ModuleList([ - ResBlock(in_channels, use_spectral_norm, padding_mode, res_norm_type, activation_type=activation_type) - for _ in range(num_res_blocks) - ]) - sequence = list() - channels = in_channels - for i in range(num_up_sampling): - sequence.append(nn.Sequential( - nn.Upsample(scale_factor=2), - Conv2dBlock(channels, channels // 2, - kernel_size=5, stride=1, padding=2, padding_mode=padding_mode, - use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type - ), - )) - channels = channels // 2 - sequence.append( - Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect", - use_spectral_norm=use_spectral_norm, activation_type="Tanh", norm_type="NONE")) - self.sequence = nn.Sequential(*sequence) - - def forward(self, x): - for blk in self.res_blocks: - x = blk(x) - return self.sequence(x) - - -class Fusion(nn.Module): - def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"): - super().__init__() - norm_layer = select_norm_layer(norm_type) - self.start_fc = nn.Sequential( - nn.Linear(in_features, base_features), - norm_layer(base_features), - nn.ReLU(True), - ) - self.fcs = nn.Sequential(*[ - nn.Sequential( - nn.Linear(base_features, base_features), - norm_layer(base_features), - nn.ReLU(True), - ) for _ in range(n_blocks - 2) - ]) - self.end_fc = nn.Sequential( - nn.Linear(base_features, out_features), - ) - - def forward(self, x): - x = self.start_fc(x) - x = self.fcs(x) - return self.end_fc(x) - - -@MODEL.register_module("MUNIT-Generator") -class Generator(nn.Module): - def __init__(self, in_channels, out_channels, base_channels, num_sampling, num_style_dim, num_style_conv, - num_content_res_blocks, num_decoder_res_blocks, num_fusion_dim, num_fusion_blocks, - use_spectral_norm=False, activation_type="ReLU", padding_mode='reflect'): - super().__init__() - self.num_decoder_res_blocks = num_decoder_res_blocks - self.content_encoder = ContentEncoder(in_channels, num_sampling, num_content_res_blocks, base_channels, - use_spectral_norm, padding_mode, activation_type, norm_type="IN") - self.style_encoder = StyleEncoder(in_channels, num_style_dim, num_style_conv, base_channels, use_spectral_norm, - padding_mode, activation_type, norm_type="NONE") - content_channels = base_channels * (2 ** 2) - self.decoder = Decoder(content_channels, out_channels, num_sampling, - num_decoder_res_blocks, use_spectral_norm, "AdaIN", norm_type="LN", - activation_type=activation_type, padding_mode=padding_mode) - self.fusion = Fusion(num_style_dim, num_decoder_res_blocks * 2 * content_channels * 2, - base_features=num_fusion_dim, n_blocks=num_fusion_blocks, norm_type="NONE") - - def encode(self, x): - return self.content_encoder(x), self.style_encoder(x) - - def decode(self, content, style): - as_param_style = torch.chunk(self.fusion(style), self.num_decoder_res_blocks * 2, dim=1) - # set style for decoder - for i, blk in enumerate(self.decoder.res_blocks): - blk.conv1.normalization.set_style(as_param_style[2 * i]) - blk.conv2.normalization.set_style(as_param_style[2 * i + 1]) - return self.decoder(content) - - def forward(self, x): - content, style = self.encode(x) - return self.decode(content, style) diff --git a/model/GAN/TAFG.py b/model/GAN/TAFG.py deleted file mode 100644 index a1fb0ec..0000000 --- a/model/GAN/TAFG.py +++ /dev/null @@ -1,171 +0,0 @@ -import torch -import torch.nn as nn -from torchvision.models import vgg19 - -from model.normalization import select_norm_layer -from model.registry import MODEL -from .MUNIT import ContentEncoder, Fusion, Decoder, StyleEncoder -from .base import ResBlock - - -class VGG19StyleEncoder(nn.Module): - def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE", - vgg19_layers=(0, 5, 10, 19), fix_vgg19=True): - super().__init__() - self.vgg19_layers = vgg19_layers - self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1] - self.vgg19.requires_grad_(not fix_vgg19) - - norm_layer = select_norm_layer(norm_type) - - self.conv0 = nn.Sequential( - nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode, - bias=True), - norm_layer(base_channels), - nn.ReLU(True), - ) - self.conv = nn.ModuleList([ - nn.Sequential( - nn.Conv2d(base_channels * (2 ** i), base_channels * (2 ** i), kernel_size=4, stride=2, padding=1, - padding_mode=padding_mode, bias=True), - norm_layer(base_channels), - nn.ReLU(True), - ) for i in range(1, 4) - ]) - self.pool = nn.AdaptiveAvgPool2d(1) - self.conv1x1 = nn.Conv2d(base_channels * (2 ** 4), style_dim, kernel_size=1, stride=1, padding=0) - - def fixed_style_features(self, x): - features = [] - for i in range(len(self.vgg19)): - x = self.vgg19[i](x) - if i in self.vgg19_layers: - features.append(x) - return features - - def forward(self, x): - fsf = self.fixed_style_features(x) - x = self.conv0(x) - for i, l in enumerate(self.conv): - x = l(torch.cat([x, fsf[i]], dim=1)) - x = self.pool(torch.cat([x, fsf[-1]], dim=1)) - x = self.conv1x1(x) - return x.view(x.size(0), -1) - - -@MODEL.register_module("TAFG-ResGenerator") -class ResGenerator(nn.Module): - def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64): - super().__init__() - self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks, - use_spectral_norm=use_spectral_norm) - resnet_channels = 2 ** 2 * base_channels - self.decoder = Decoder(resnet_channels, out_channels, 2, - 0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect") - - def forward(self, x): - return self.decoder(self.content_encoder(x)) - - -@MODEL.register_module("TAFG-SingleGenerator") -class SingleGenerator(nn.Module): - def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False, - style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8, - num_res_blocks=8, base_channels=64, padding_mode="reflect"): - super().__init__() - self.num_adain_blocks = num_adain_blocks - if style_encoder_type == "StyleEncoder": - self.style_encoder = StyleEncoder( - style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm, - max_multiple=4, padding_mode=padding_mode, norm_type="NONE" - ) - elif style_encoder_type == "VGG19StyleEncoder": - self.style_encoder = VGG19StyleEncoder( - style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE" - ) - else: - raise NotImplemented(f"do not support {style_encoder_type}") - - resnet_channels = 2 ** 2 * base_channels - self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, - n_blocks=3, norm_type="NONE") - self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks, - use_spectral_norm=use_spectral_norm) - - self.decoder = Decoder(resnet_channels, out_channels, 2, - num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode) - - def forward(self, content_img, style_img): - content = self.content_encoder(content_img) - style = self.style_encoder(style_img) - as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1) - # set style for decoder - for i, blk in enumerate(self.decoder.res_blocks): - blk.conv1.normalization.set_style(as_param_style[2 * i]) - blk.conv2.normalization.set_style(as_param_style[2 * i + 1]) - return self.decoder(content) - - -@MODEL.register_module("TAFG-Generator") -class Generator(nn.Module): - def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False, - style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8, - num_res_blocks=8, base_channels=64, padding_mode="reflect"): - super(Generator, self).__init__() - self.num_adain_blocks = num_adain_blocks - if style_encoder_type == "StyleEncoder": - self.style_encoders = nn.ModuleDict(dict( - a=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm, - max_multiple=4, padding_mode=padding_mode, norm_type="NONE"), - b=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm, - max_multiple=4, padding_mode=padding_mode, norm_type="NONE"), - )) - elif style_encoder_type == "VGG19StyleEncoder": - self.style_encoders = nn.ModuleDict(dict( - a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, - norm_type="NONE"), - b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, - norm_type="NONE", fix_vgg19=False) - )) - else: - raise NotImplemented(f"do not support {style_encoder_type}") - resnet_channels = 2 ** 2 * base_channels - self.style_converters = nn.ModuleDict(dict( - a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3, - norm_type="NONE"), - b=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3, - norm_type="NONE"), - )) - self.content_encoders = nn.ModuleDict({ - "a": ContentEncoder(content_in_channels, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm), - "b": ContentEncoder(1, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm) - }) - - self.content_resnet = nn.Sequential(*[ - ResBlock(resnet_channels, use_spectral_norm, padding_mode, "IN") - for _ in range(num_res_blocks) - ]) - self.decoders = nn.ModuleDict(dict( - a=Decoder(resnet_channels, out_channels, 2, - num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode), - b=Decoder(resnet_channels, out_channels, 2, - num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode), - )) - - def encode(self, content_img, style_img, which_content, which_style): - content = self.content_resnet(self.content_encoders[which_content](content_img)) - style = self.style_encoders[which_style](style_img) - return content, style - - def decode(self, content, style, which): - decoder = self.decoders[which] - as_param_style = torch.chunk(self.style_converters[which](style), self.num_adain_blocks * 2, dim=1) - # set style for decoder - for i, blk in enumerate(decoder.res_blocks): - blk.conv1.normalization.set_style(as_param_style[2 * i]) - blk.conv2.normalization.set_style(as_param_style[2 * i + 1]) - return decoder(content) - - def forward(self, content_img, style_img, which_content, which_style): - content, style = self.encode(content_img, style_img, which_content, which_style) - return self.decode(content, style, which_style) diff --git a/model/GAN/TSIT.py b/model/GAN/TSIT.py deleted file mode 100644 index 3422a51..0000000 --- a/model/GAN/TSIT.py +++ /dev/null @@ -1,88 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from model import MODEL -from model.base.module import Conv2dBlock, ResidualBlock, ReverseResidualBlock - - -class Interpolation(nn.Module): - def __init__(self, scale_factor=None, mode='nearest', align_corners=None): - super(Interpolation, self).__init__() - self.scale_factor = scale_factor - self.mode = mode - self.align_corners = align_corners - - def forward(self, x): - return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners, - recompute_scale_factor=False) - - def __repr__(self): - return f"DownSampling(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})" - - -@MODEL.register_module("TSIT-Generator") -class Generator(nn.Module): - def __init__(self, content_in_channels=3, out_channels=3, base_channels=64, num_blocks=7, - padding_mode="reflect", activation_type="ReLU"): - super().__init__() - self.num_blocks = num_blocks - self.base_channels = base_channels - - self.content_stream = self.build_stream(padding_mode, activation_type) - self.start_conv = Conv2dBlock(content_in_channels, base_channels, activation_type=activation_type, - norm_type="IN", kernel_size=7, padding_mode=padding_mode, padding=3) - - sequence = [] - multiple_now = min(2 ** self.num_blocks, 2 ** 4) - for i in range(1, self.num_blocks + 1): - m = self.num_blocks - i - multiple_prev = multiple_now - multiple_now = min(2 ** m, 2 ** 4) - sequence.append(nn.Sequential( - ReverseResidualBlock( - multiple_prev * base_channels, multiple_now * base_channels, - padding_mode=padding_mode, norm_type="FADE", - additional_norm_kwargs=dict( - condition_in_channels=multiple_prev * base_channels, - base_norm_type="BN", - padding_mode=padding_mode - ) - ), - Interpolation(2, mode="nearest") - )) - self.generator = nn.Sequential(*sequence) - self.end_conv = Conv2dBlock(base_channels, out_channels, activation_type="Tanh", - kernel_size=7, padding_mode=padding_mode, padding=3) - - def build_stream(self, padding_mode, activation_type): - multiple_now = 1 - stream_sequence = [] - for i in range(1, self.num_blocks + 1): - multiple_prev = multiple_now - multiple_now = min(2 ** i, 2 ** 4) - stream_sequence.append(nn.Sequential( - Interpolation(scale_factor=0.5, mode="nearest"), - ResidualBlock( - multiple_prev * self.base_channels, multiple_now * self.base_channels, - padding_mode=padding_mode, activation_type=activation_type, norm_type="IN") - )) - return nn.ModuleList(stream_sequence) - - def forward(self, content, z=None): - c = self.start_conv(content) - content_features = [] - for i in range(self.num_blocks): - c = self.content_stream[i](c) - content_features.append(c) - if z is None: - z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device) - - for i in range(self.num_blocks): - m = - i - 1 - res_block = self.generator[i][0] - res_block.conv1.normalization.set_feature(content_features[m]) - res_block.conv2.normalization.set_feature(content_features[m]) - if res_block.learn_skip_connection: - res_block.res_conv.normalization.set_feature(content_features[m]) - return self.end_conv(self.generator(z)) diff --git a/model/GAN/UGATIT.py b/model/GAN/UGATIT.py deleted file mode 100644 index f90375a..0000000 --- a/model/GAN/UGATIT.py +++ /dev/null @@ -1,236 +0,0 @@ -import torch -import torch.nn as nn -from .base import ResidualBlock -from model.registry import MODEL - - -class RhoClipper(object): - def __init__(self, clip_min, clip_max): - self.clip_min = clip_min - self.clip_max = clip_max - assert clip_min < clip_max - - def __call__(self, module): - if hasattr(module, 'rho'): - w = module.rho.data - w = w.clamp(self.clip_min, self.clip_max) - module.rho.data = w - - -@MODEL.register_module("UGATIT-Generator") -class Generator(nn.Module): - def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False): - assert (num_blocks >= 0) - super(Generator, self).__init__() - self.input_channels = in_channels - self.output_channels = out_channels - self.base_channels = base_channels - self.num_blocks = num_blocks - self.img_size = img_size - self.light = light - - down_encoder = [nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, - padding_mode="reflect", bias=False), - nn.InstanceNorm2d(base_channels), - nn.ReLU(True)] - - n_down_sampling = 2 - for i in range(n_down_sampling): - mult = 2 ** i - down_encoder += [nn.Conv2d(base_channels * mult, base_channels * mult * 2, kernel_size=3, stride=2, - padding=1, bias=False, padding_mode="reflect"), - nn.InstanceNorm2d(base_channels * mult * 2), - nn.ReLU(True)] - - # Down-Sampling Bottleneck - mult = 2 ** n_down_sampling - for i in range(num_blocks): - down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)] - self.down_encoder = nn.Sequential(*down_encoder) - - # Class Activation Map - self.gap_fc = nn.Linear(base_channels * mult, 1, bias=False) - self.gmp_fc = nn.Linear(base_channels * mult, 1, bias=False) - self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True) - self.relu = nn.ReLU(True) - - # Gamma, Beta block - if self.light: - fc = [nn.Linear(base_channels * mult, base_channels * mult, bias=False), - nn.ReLU(True), - nn.Linear(base_channels * mult, base_channels * mult, bias=False), - nn.ReLU(True)] - else: - fc = [ - nn.Linear(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, bias=False), - nn.ReLU(True), - nn.Linear(base_channels * mult, base_channels * mult, bias=False), - nn.ReLU(True)] - self.fc = nn.Sequential(*fc) - - self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False) - self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False) - - # Up-Sampling Bottleneck - self.up_bottleneck = nn.ModuleList( - [ResnetAdaILNBlock(base_channels * mult, use_bias=False) for _ in range(num_blocks)]) - - # Up-Sampling - up_decoder = [] - for i in range(n_down_sampling): - mult = 2 ** (n_down_sampling - i) - up_decoder += [nn.Upsample(scale_factor=2, mode='nearest'), - nn.Conv2d(base_channels * mult, base_channels * mult // 2, kernel_size=3, stride=1, - padding=1, padding_mode="reflect", bias=False), - ILN(base_channels * mult // 2), - nn.ReLU(True)] - - up_decoder += [nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3, - padding_mode="reflect", bias=False), - nn.Tanh()] - self.up_decoder = nn.Sequential(*up_decoder) - - def forward(self, x): - x = self.down_encoder(x) - - gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) - gap_logit = self.gap_fc(gap.view(x.shape[0], -1)) - gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3) - - gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) - gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1)) - gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3) - - cam_logit = torch.cat([gap_logit, gmp_logit], 1) - - x = torch.cat([gap, gmp], 1) - x = self.relu(self.conv1x1(x)) - - heatmap = torch.sum(x, dim=1, keepdim=True) - - if self.light: - x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1) - x_ = self.fc(x_.view(x_.shape[0], -1)) - else: - x_ = self.fc(x.view(x.shape[0], -1)) - gamma, beta = self.gamma(x_), self.beta(x_) - - for ub in self.up_bottleneck: - x = ub(x, gamma, beta) - - x = self.up_decoder(x) - return x, cam_logit, heatmap - - -class ResnetAdaILNBlock(nn.Module): - def __init__(self, dim, use_bias): - super(ResnetAdaILNBlock, self).__init__() - self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect") - self.norm1 = AdaILN(dim) - self.relu1 = nn.ReLU(True) - - self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect") - self.norm2 = AdaILN(dim) - - def forward(self, x, gamma, beta): - out = self.conv1(x) - out = self.norm1(out, gamma, beta) - out = self.relu1(out) - out = self.conv2(out) - out = self.norm2(out, gamma, beta) - - return out + x - - -def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5): - in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True) - out_in = (x - in_mean) / torch.sqrt(in_var + eps) - ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True) - out_ln = (x - ln_mean) / torch.sqrt(ln_var + eps) - out = rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - rho.expand(x.shape[0], -1, -1, -1)) * out_ln - out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3) - return out - - -class AdaILN(nn.Module): - def __init__(self, num_features, eps=1e-5, default_rho=0.9): - super(AdaILN, self).__init__() - self.eps = eps - self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) - self.rho.data.fill_(default_rho) - - def forward(self, x, gamma, beta): - return instance_layer_normalization(x, gamma, beta, self.rho, self.eps) - - -class ILN(nn.Module): - def __init__(self, num_features, eps=1e-5): - super(ILN, self).__init__() - self.eps = eps - self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) - self.gamma = nn.Parameter(torch.Tensor(1, num_features)) - self.beta = nn.Parameter(torch.Tensor(1, num_features)) - self.rho.data.fill_(0.0) - self.gamma.data.fill_(1.0) - self.beta.data.fill_(0.0) - - def forward(self, x): - return instance_layer_normalization( - x, self.gamma.expand(x.shape[0], -1), self.beta.expand(x.shape[0], -1), self.rho, self.eps) - - -@MODEL.register_module("UGATIT-Discriminator") -class Discriminator(nn.Module): - def __init__(self, in_channels, base_channels=64, num_blocks=5): - super(Discriminator, self).__init__() - encoder = [self.build_conv_block(in_channels, base_channels)] - - for i in range(1, num_blocks - 2): - mult = 2 ** (i - 1) - encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2)) - - mult = 2 ** (num_blocks - 2 - 1) - encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2, stride=1)) - - self.encoder = nn.Sequential(*encoder) - - # Class Activation Map - mult = 2 ** (num_blocks - 2) - self.gap_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False)) - self.gmp_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False)) - self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True) - self.leaky_relu = nn.LeakyReLU(0.2, True) - - self.conv = nn.utils.spectral_norm( - nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode="reflect")) - - @staticmethod - def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"): - return nn.Sequential(*[ - nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, - bias=True, padding=padding, padding_mode=padding_mode)), - nn.LeakyReLU(0.2, True), - ]) - - def forward(self, x, return_heatmap=False): - x = self.encoder(x) - batch_size = x.size(0) - - gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) # B x C x 1 x 1, avg of per channel - gap_logit = self.gap_fc(gap.view(batch_size, -1)) - gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3) - - gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) - gmp_logit = self.gmp_fc(gmp.view(batch_size, -1)) - gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3) - - cam_logit = torch.cat([gap_logit, gmp_logit], 1) - - x = torch.cat([gap, gmp], 1) - x = self.leaky_relu(self.conv1x1(x)) - - if return_heatmap: - heatmap = torch.sum(x, dim=1, keepdim=True) - return self.conv(x), cam_logit, heatmap - else: - return self.conv(x), cam_logit diff --git a/model/GAN/__init__.py b/model/GAN/__init__.py deleted file mode 100644 index 7195a56..0000000 --- a/model/GAN/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from util.misc import import_submodules - -__all__ = import_submodules(__name__).keys() \ No newline at end of file diff --git a/model/GAN/base.py b/model/GAN/base.py deleted file mode 100644 index 856e3cd..0000000 --- a/model/GAN/base.py +++ /dev/null @@ -1,203 +0,0 @@ -from functools import partial - -import math -import torch -import torch.nn as nn - -from model import MODEL -from model.normalization import select_norm_layer - - -class GANImageBuffer(object): - """This class implements an image buffer that stores previously - generated images. - This buffer allows us to update the discriminator using a history of - generated images rather than the ones produced by the latest generator - to reduce model oscillation. - Args: - buffer_size (int): The size of image buffer. If buffer_size = 0, - no buffer will be created. - buffer_ratio (float): The chance / possibility to use the images - previously stored in the buffer. - """ - - def __init__(self, buffer_size, buffer_ratio=0.5): - self.buffer_size = buffer_size - # create an empty buffer - if self.buffer_size > 0: - self.img_num = 0 - self.image_buffer = [] - self.buffer_ratio = buffer_ratio - - def query(self, images): - """Query current image batch using a history of generated images. - Args: - images (Tensor): Current image batch without history information. - """ - if self.buffer_size == 0: # if the buffer size is 0, do nothing - return images - return_images = [] - for image in images: - image = torch.unsqueeze(image.data, 0) - # if the buffer is not full, keep inserting current images - if self.img_num < self.buffer_size: - self.img_num = self.img_num + 1 - self.image_buffer.append(image) - return_images.append(image) - else: - use_buffer = torch.rand(1) < self.buffer_ratio - # by self.buffer_ratio, the buffer will return a previously - # stored image, and insert the current image into the buffer - if use_buffer: - random_id = torch.randint(0, self.buffer_size, (1,)).item() - image_tmp = self.image_buffer[random_id].clone() - self.image_buffer[random_id] = image - return_images.append(image_tmp) - # by (1 - self.buffer_ratio), the buffer will return the - # current image - else: - return_images.append(image) - # collect all the images and return - return_images = torch.cat(return_images, 0) - return return_images - - -# based SPADE or pix2pixHD Discriminator -@MODEL.register_module("PatchDiscriminator") -class PatchDiscriminator(nn.Module): - def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN", - need_intermediate_feature=False): - super().__init__() - self.need_intermediate_feature = need_intermediate_feature - - kernel_size = 4 - padding = math.ceil((kernel_size - 1.0) / 2) - norm_layer = select_norm_layer(norm_type) - use_bias = norm_type == "IN" - padding_mode = "zeros" - - sequence = [nn.Sequential( - nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding), - nn.LeakyReLU(0.2, False) - )] - multiple_now = 1 - for i in range(1, num_conv): - multiple_prev = multiple_now - multiple_now = min(2 ** i, 2 ** 3) - stride = 1 if i == num_conv - 1 else 2 - sequence.append(nn.Sequential( - self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now, - kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode), - norm_layer(base_channels * multiple_now), - nn.LeakyReLU(0.2, inplace=False), - )) - multiple_now = min(2 ** num_conv, 8) - sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding, - padding_mode=padding_mode)) - self.conv_blocks = nn.ModuleList(sequence) - - @staticmethod - def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding, - bias=True, padding_mode: str = 'zeros'): - conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias, padding_mode=padding_mode) - if not use_spectral: - return conv - return nn.utils.spectral_norm(conv) - - def forward(self, x): - if self.need_intermediate_feature: - intermediate_feature = [] - for layer in self.conv_blocks: - x = layer(x) - intermediate_feature.append(x) - return tuple(intermediate_feature) - else: - for layer in self.conv_blocks: - x = layer(x) - return x - - -@MODEL.register_module() -class ResidualBlock(nn.Module): - def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None): - super(ResidualBlock, self).__init__() - if use_bias is None: - # Only for IN, use bias since it does not have affine parameters. - use_bias = norm_type == "IN" - norm_layer = select_norm_layer(norm_type) - self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, - bias=use_bias) - self.norm1 = norm_layer(num_channels) - self.relu1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, - bias=use_bias) - self.norm2 = norm_layer(num_channels) - - def forward(self, x): - res = x - x = self.relu1(self.norm1(self.conv1(x))) - x = self.norm2(self.conv2(x)) - return x + res - - -_DO_NO_THING_FUNC = lambda x: x - - -def select_activation(t): - if t == "ReLU": - return partial(nn.ReLU, inplace=True) - elif t == "LeakyReLU": - return partial(nn.LeakyReLU, negative_slope=0.2, inplace=True) - elif t == "Tanh": - return partial(nn.Tanh) - elif t == "NONE": - return _DO_NO_THING_FUNC - else: - raise NotImplemented - - -def _use_bias_checker(norm_type): - return norm_type not in ["IN", "BN", "AdaIN"] - - -class Conv2dBlock(nn.Module): - def __init__(self, in_channels: int, out_channels: int, use_spectral_norm=False, activation_type="ReLU", - bias=None, norm_type="NONE", **conv_kwargs): - super().__init__() - self.norm_type = norm_type - self.activation_type = activation_type - conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias - conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs) - self.convolution = nn.utils.spectral_norm(conv) if use_spectral_norm else conv - if norm_type != "NONE": - self.normalization = select_norm_layer(norm_type)(out_channels) - if activation_type != "NONE": - self.activation = select_activation(activation_type)() - - def forward(self, x): - x = self.convolution(x) - if self.norm_type != "NONE": - x = self.normalization(x) - if self.activation_type != "NONE": - x = self.activation(x) - return x - - -class ResBlock(nn.Module): - def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect', - norm_type="IN", activation_type="ReLU", use_bias=None): - super().__init__() - self.norm_type = norm_type - if use_bias is None: - # bias will be canceled after channel wise normalization - use_bias = _use_bias_checker(norm_type) - - self.conv1 = Conv2dBlock(num_channels, num_channels, use_spectral_norm, - kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias, - norm_type=norm_type, activation_type=activation_type) - self.conv2 = Conv2dBlock(num_channels, num_channels, use_spectral_norm, - kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias, - norm_type=norm_type, activation_type="NONE") - - def forward(self, x): - return self.conv2(self.conv1(x)) + x diff --git a/model/GAN/wrapper.py b/model/GAN/wrapper.py deleted file mode 100644 index f5b7538..0000000 --- a/model/GAN/wrapper.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F - -from model import MODEL - - -@MODEL.register_module() -class MultiScaleDiscriminator(nn.Module): - def __init__(self, num_scale, discriminator_cfg): - super().__init__() - - self.discriminator_list = nn.ModuleList([ - MODEL.build_with(discriminator_cfg) for _ in range(num_scale) - ]) - - @staticmethod - def down_sample(x): - return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False) - - def forward(self, x): - results = [] - for discriminator in self.discriminator_list: - results.append(discriminator(x)) - x = self.down_sample(x) - return results diff --git a/model/__init__.py b/model/__init__.py index aef3eef..83051a6 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -1,4 +1,3 @@ from model.registry import MODEL, NORMALIZATION -import model.GAN import model.base.normalization - +import model.image_translation diff --git a/model/base/module.py b/model/base/module.py index d0c6da6..7a7a765 100644 --- a/model/base/module.py +++ b/model/base/module.py @@ -30,7 +30,24 @@ def _activation(activation): elif activation == "Tanh": return nn.Tanh() else: - raise NotImplemented(activation) + raise NotImplementedError(f"{activation} not valid") + + +class LinearBlock(nn.Module): + def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE"): + super().__init__() + + self.norm_type = norm_type + self.activation_type = activation_type + + bias = _use_bias_checker(norm_type) if bias is None else bias + self.linear = nn.Linear(in_features, out_features, bias) + + self.normalization = _normalization(norm_type, out_features) + self.activation = _activation(activation_type) + + def forward(self, x): + return self.activation(self.normalization(self.linear(x))) class Conv2dBlock(nn.Module): diff --git a/model/base/normalization.py b/model/base/normalization.py index 7a925f4..30b2e12 100644 --- a/model/base/normalization.py +++ b/model/base/normalization.py @@ -93,7 +93,7 @@ class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization): def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5): out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps) - out = out * gamma + beta + out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3) return out @@ -115,7 +115,7 @@ class ILN(nn.Module): def forward(self, x): return _instance_layer_normalization( - x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps) + x, self.gamma.view(1, -1), self.beta.view(1, -1), self.rho.view(1, -1, 1, 1), self.eps) @NORMALIZATION.register_module("AdaILN") @@ -136,7 +136,6 @@ class AdaILN(nn.Module): def forward(self, x): assert self.have_set_condition - out = _instance_layer_normalization( - x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps) + out = _instance_layer_normalization(x, self.gamma, self.beta, self.rho.view(1, -1, 1, 1), self.eps) self.have_set_condition = False return out diff --git a/model/image_translation/CycleGAN.py b/model/image_translation/CycleGAN.py new file mode 100644 index 0000000..e69de29 diff --git a/model/image_translation/GauGAN.py b/model/image_translation/GauGAN.py new file mode 100644 index 0000000..e69de29 diff --git a/model/image_translation/MUNIT.py b/model/image_translation/MUNIT.py new file mode 100644 index 0000000..8dc3af0 --- /dev/null +++ b/model/image_translation/MUNIT.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn + +from model import MODEL +from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock + + +def _get_down_sampling_sequence(in_channels, base_channels, num_conv, max_down_sampling_multiple=2, + padding_mode='reflect', activation_type="ReLU", norm_type="NONE"): + sequence = [Conv2dBlock( + in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode, + activation_type=activation_type, norm_type=norm_type + )] + multiple_now = 1 + for i in range(1, num_conv + 1): + multiple_prev = multiple_now + multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple) + sequence.append(Conv2dBlock( + multiple_prev * base_channels, multiple_now * base_channels, + kernel_size=4, stride=2, padding=1, padding_mode=padding_mode, + activation_type=activation_type, norm_type=norm_type + )) + return sequence, multiple_now * base_channels + + +class StyleEncoder(nn.Module): + def __init__(self, in_channels, out_dim, num_conv, base_channels=64, + max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"): + super().__init__() + + sequence, last_channels = _get_down_sampling_sequence( + in_channels, base_channels, num_conv, + max_down_sampling_multiple, padding_mode, activation_type, norm_type + ) + sequence.append(nn.AdaptiveAvgPool2d(1)) + # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code + sequence.append(nn.Conv2d(last_channels, out_dim, kernel_size=1, stride=1, padding=0)) + self.sequence = nn.Sequential(*sequence) + + def forward(self, image): + return self.sequence(image).view(image.size(0), -1) + + +class ContentEncoder(nn.Module): + def __init__(self, in_channels, num_down_sampling, num_residual_blocks, base_channels=64, + max_down_sampling_multiple=2, + padding_mode='reflect', activation_type="ReLU", norm_type="IN"): + super().__init__() + + sequence, last_channels = _get_down_sampling_sequence( + in_channels, base_channels, num_down_sampling, + max_down_sampling_multiple, padding_mode, activation_type, norm_type + ) + + sequence += [ResidualBlock(last_channels, last_channels, padding_mode, activation_type, norm_type) for _ in + range(num_residual_blocks)] + self.sequence = nn.Sequential(*sequence) + + def forward(self, image): + return self.sequence(image) + + +class Decoder(nn.Module): + def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks, + res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU", padding_mode='reflect'): + super().__init__() + self.residual_blocks = nn.ModuleList([ + ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type) + for _ in range(num_residual_blocks) + ]) + + sequence = list() + channels = in_channels + for i in range(num_up_sampling): + sequence.append(nn.Sequential( + nn.Upsample(scale_factor=2), + Conv2dBlock(channels, channels // 2, + kernel_size=5, stride=1, padding=2, padding_mode=padding_mode, + activation_type=activation_type, norm_type=norm_type), + )) + channels = channels // 2 + sequence.append(Conv2dBlock(channels, out_channels, + kernel_size=7, stride=1, padding=3, padding_mode="reflect", + activation_type="Tanh", norm_type="NONE")) + + self.up_sequence = nn.Sequential(*sequence) + + def forward(self, x, style): + as_param_style = torch.chunk(style, 2 * len(self.residual_blocks), dim=1) + # set style for decoder + for i, blk in enumerate(self.residual_blocks): + blk.conv1.normalization.set_style(as_param_style[2 * i]) + blk.conv2.normalization.set_style(as_param_style[2 * i + 1]) + x = blk(x) + return self.up_sequence(x) + + +class MLPFusion(nn.Module): + def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"): + super().__init__() + + sequence = [LinearBlock(in_features, base_features, activation_type=activation_type, norm_type=norm_type)] + sequence += [ + LinearBlock(base_features, base_features, activation_type=activation_type, norm_type=norm_type) + for _ in range(n_blocks - 2) + ] + sequence.append(LinearBlock(base_features, out_features, activation_type=activation_type, norm_type=norm_type)) + self.sequence = nn.Sequential(*sequence) + + def forward(self, x): + return self.sequence(x) + + +@MODEL.register_module("MUNIT-Generator") +class Generator(nn.Module): + def __init__(self, in_channels, out_channels, base_channels=64, style_dim=8, + num_mlp_base_feature=256, num_mlp_blocks=3, + max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2, + encoder_num_residual_blocks=4, decoder_num_residual_blocks=4, + padding_mode='reflect', activation_type="ReLU"): + super().__init__() + self.content_encoder = ContentEncoder( + in_channels, num_content_down_sampling, encoder_num_residual_blocks, + base_channels, max_down_sampling_multiple, + padding_mode, activation_type, norm_type="IN") + + self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels, + max_down_sampling_multiple, padding_mode, activation_type, + norm_type="NONE") + + content_channels = base_channels * (2 ** max_down_sampling_multiple) + + self.fusion = MLPFusion(style_dim, decoder_num_residual_blocks * 2 * content_channels * 2, + num_mlp_base_feature, num_mlp_blocks, activation_type, + norm_type="NONE") + + self.decoder = Decoder(content_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks, + res_norm_type="AdaIN", norm_type="LN", activation_type=activation_type, + padding_mode=padding_mode) + + def encode(self, x): + return self.content_encoder(x), self.style_encoder(x) + + def decode(self, content, style): + self.decoder(content, self.fusion(style)) + + def forward(self, x): + content, style = self.encode(x) + return self.decode(content, style) diff --git a/model/image_translation/TSIT.py b/model/image_translation/TSIT.py new file mode 100644 index 0000000..e69de29 diff --git a/model/image_translation/UGATIT.py b/model/image_translation/UGATIT.py new file mode 100644 index 0000000..d288af7 --- /dev/null +++ b/model/image_translation/UGATIT.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from model import MODEL +from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock + + +class RhoClipper(object): + def __init__(self, clip_min, clip_max): + self.clip_min = clip_min + self.clip_max = clip_max + assert clip_min < clip_max + + def __call__(self, module): + if hasattr(module, 'rho'): + w = module.rho.data + w = w.clamp(self.clip_min, self.clip_max) + module.rho.data = w + + +class CAMClassifier(nn.Module): + def __init__(self, in_channels, activation_type="ReLU"): + super(CAMClassifier, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.avg_fc = nn.Linear(in_channels, 1, bias=False) + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.max_fc = nn.Linear(in_channels, 1, bias=False) + self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, activation_type=activation_type, + norm_type="NONE", kernel_size=1, stride=1, bias=True) + + def forward(self, x): + avg_logit = self.avg_fc(self.avg_pool(x).view(x.size(0), -1)) + max_logit = self.max_fc(self.max_pool(x).view(x.size(0), -1)) + + return self.fusion_conv(torch.cat( + [x * self.avg_fc.weight.unsqueeze(2).unsqueeze(3), x * self.max_fc.weight.unsqueeze(2).unsqueeze(3)], + dim=1 + )), torch.cat([avg_logit, max_logit], 1) + + +@MODEL.register_module("UGATIT-Generator") +class Generator(nn.Module): + def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False, + activation_type="ReLU", norm_type="IN", padding_mode='reflect'): + super(Generator, self).__init__() + + self.light = light + + sequence = [Conv2dBlock( + in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode, + activation_type=activation_type, norm_type=norm_type + )] + + n_down_sampling = 2 + for i in range(n_down_sampling): + mult = 2 ** i + sequence.append(Conv2dBlock( + base_channels * mult, base_channels * mult * 2, + kernel_size=3, stride=2, padding=1, padding_mode=padding_mode, + activation_type=activation_type, norm_type=norm_type + )) + + mult = 2 ** n_down_sampling + sequence += [ + ResidualBlock(base_channels * mult, base_channels * mult, padding_mode, activation_type=activation_type, + norm_type=norm_type) + for _ in range(num_blocks)] + self.encoder = nn.Sequential(*sequence) + + self.cam = CAMClassifier(base_channels * mult, activation_type) + + # Gamma, Beta block + if self.light: + self.fc = nn.Sequential( + LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE"), + LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE") + ) + else: + self.fc = nn.Sequential( + LinearBlock(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, False, + "ReLU", "NONE"), + LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE") + ) + + self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False) + self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False) + + # Up-Sampling Bottleneck + self.up_bottleneck = nn.ModuleList( + [ResidualBlock(base_channels * mult, base_channels * mult, padding_mode, + activation_type, norm_type="AdaILN") for _ in range(num_blocks)]) + + sequence = list() + channels = base_channels * mult + for i in range(n_down_sampling): + sequence.append(nn.Sequential( + nn.Upsample(scale_factor=2), + Conv2dBlock(channels, channels // 2, + kernel_size=3, stride=1, padding=1, bias=False, padding_mode=padding_mode, + activation_type=activation_type, norm_type="ILN"), + )) + channels = channels // 2 + sequence.append(Conv2dBlock(channels, out_channels, + kernel_size=7, stride=1, padding=3, padding_mode="reflect", + activation_type="Tanh", norm_type="NONE")) + self.decoder = nn.Sequential(*sequence) + + def forward(self, x): + x = self.encoder(x) + + x, cam_logit = self.cam(x) + + heatmap = torch.sum(x, dim=1, keepdim=True) + + if self.light: + x_ = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)) + x_ = self.fc(x_.view(x_.shape[0], -1)) + else: + x_ = self.fc(x.view(x.shape[0], -1)) + gamma, beta = self.gamma(x_), self.beta(x_) + + for blk in self.up_bottleneck: + blk.conv1.normalization.set_condition(gamma, beta) + blk.conv2.normalization.set_condition(gamma, beta) + x = blk(x) + return self.decoder(x), cam_logit, heatmap + + +@MODEL.register_module("UGATIT-Discriminator") +class Discriminator(nn.Module): + def __init__(self, in_channels, base_channels=64, num_blocks=5, + activation_type="LeakyReLU", norm_type="NONE", padding_mode='reflect'): + super().__init__() + + sequence = [Conv2dBlock( + in_channels, base_channels, kernel_size=4, stride=2, padding=1, padding_mode=padding_mode, + activation_type=activation_type, norm_type=norm_type + )] + + sequence += [Conv2dBlock( + base_channels * (2 ** i), base_channels * (2 ** i) * 2, + kernel_size=4, stride=2, padding=1, padding_mode=padding_mode, + activation_type=activation_type, norm_type=norm_type) for i in range(num_blocks - 3)] + + sequence.append( + Conv2dBlock(base_channels * (2 ** (num_blocks - 3)), base_channels * (2 ** (num_blocks - 2)), + kernel_size=4, stride=1, padding=1, padding_mode=padding_mode, + activation_type=activation_type, norm_type=norm_type) + ) + self.sequence = nn.Sequential(*sequence) + + mult = 2 ** (num_blocks - 2) + self.cam = CAMClassifier(base_channels * mult, activation_type) + self.conv = nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False, + padding_mode="reflect") + + def forward(self, x, return_heatmap=False): + x = self.sequence(x) + + x, cam_logit = self.cam(x) + + if return_heatmap: + heatmap = torch.sum(x, dim=1, keepdim=True) + return self.conv(x), cam_logit, heatmap + else: + return self.conv(x), cam_logit diff --git a/model/image_translation/__init__.py b/model/image_translation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/misc.py b/util/misc.py index 80a61ca..7dfb5c7 100644 --- a/util/misc.py +++ b/util/misc.py @@ -8,7 +8,7 @@ import torch.nn as nn def add_spectral_norm(module): - if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'): + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'): return nn.utils.spectral_norm(module) else: return module