import torch


class LossContainer:
    """Callable wrapper that scales a loss function by a fixed weight.

    Args:
        weight (float): Multiplier applied to the wrapped loss. When it is
            not strictly positive the loss is not evaluated at all.
        loss (callable): The loss function to wrap.
    """

    def __init__(self, weight, loss):
        self.weight = weight
        self.loss = loss

    def __call__(self, *args, **kwargs):
        """Return ``weight * loss(*args, **kwargs)``, or ``0.0`` if disabled.

        All positional and keyword arguments are forwarded to the wrapped
        loss unchanged.
        """
        if self.weight > 0:
            return self.weight * self.loss(*args, **kwargs)
        # A non-positive weight disables the term; skip the computation.
        return 0.0


class GANImageBuffer:
    """This class implements an image buffer that stores previously
    generated images.

    This buffer allows us to update the discriminator using a history of
    generated images rather than the ones produced by the latest generator
    to reduce model oscillation.

    Args:
        buffer_size (int): The size of image buffer. If buffer_size <= 0,
            no buffer is used and ``query`` returns its input unchanged.
        buffer_ratio (float): The chance / possibility to use the images
            previously stored in the buffer.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        self.buffer_ratio = buffer_ratio
        # Always create the (possibly unused) buffer state so that every
        # instance has the same attributes regardless of ``buffer_size``.
        self.img_num = 0
        self.image_buffer = []

    def query(self, images):
        """Query current image batch using a history of generated images.

        Args:
            images (Tensor): Current image batch without history
                information. It is iterated along its first dimension.

        Returns:
            Tensor: A batch with the same number of images, where each image
            is either the current one or one previously stored in the
            buffer. When the buffer is disabled or the batch is empty,
            ``images`` itself is returned.
        """
        # Buffer disabled (zero or negative size): do nothing.
        if self.buffer_size <= 0:
            return images

        return_images = []
        for image in images:
            # ``.data`` drops the autograd history; re-add the batch axis so
            # the pieces can be concatenated along dim 0 at the end.
            image = torch.unsqueeze(image.data, 0)

            # If the buffer is not full, keep inserting current images.
            if self.img_num < self.buffer_size:
                self.img_num += 1
                self.image_buffer.append(image)
                return_images.append(image)
                continue

            use_buffer = torch.rand(1) < self.buffer_ratio
            if use_buffer:
                # By self.buffer_ratio, return a previously stored image
                # and insert the current image into the buffer in its place.
                random_id = torch.randint(0, self.buffer_size, (1,)).item()
                image_tmp = self.image_buffer[random_id].clone()
                self.image_buffer[random_id] = image
                return_images.append(image_tmp)
            else:
                # By (1 - self.buffer_ratio), return the current image.
                return_images.append(image)

        # An empty batch yields nothing to concatenate; ``torch.cat`` would
        # raise on an empty list, so hand the empty batch back unchanged.
        if not return_images:
            return images

        # Collect all the images and return.
        return torch.cat(return_images, 0)