From a6ffab1445c6c121fb5b2fd7e83671cf63bd0638 Mon Sep 17 00:00:00 2001 From: budui Date: Tue, 13 Oct 2020 10:30:27 +0800 Subject: [PATCH] add image buffers for gan --- engine/util/container.py | 57 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/engine/util/container.py b/engine/util/container.py index c1690fd..5b6fb5e 100644 --- a/engine/util/container.py +++ b/engine/util/container.py @@ -1,3 +1,6 @@ +import torch + + class LossContainer: def __init__(self, weight, loss): self.weight = weight @@ -7,3 +10,57 @@ class LossContainer: if self.weight > 0: return self.weight * self.loss(*args, **kwargs) return 0.0 + + +class GANImageBuffer: + """This class implements an image buffer that stores previously + generated images. + This buffer allows us to update the discriminator using a history of + generated images rather than the ones produced by the latest generator + to reduce model oscillation. + Args: + buffer_size (int): The size of image buffer. If buffer_size = 0, + no buffer will be created. + buffer_ratio (float): The chance / possibility to use the images + previously stored in the buffer. + """ + + def __init__(self, buffer_size, buffer_ratio=0.5): + self.buffer_size = buffer_size + # create an empty buffer + if self.buffer_size > 0: + self.img_num = 0 + self.image_buffer = [] + self.buffer_ratio = buffer_ratio + + def query(self, images): + """Query current image batch using a history of generated images. + Args: + images (Tensor): Current image batch without history information. + """ + if self.buffer_size == 0: # if the buffer size is 0, do nothing + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + # if the buffer is not full, keep inserting current images + if self.img_num < self.buffer_size: + self.img_num = self.img_num + 1 + self.image_buffer.append(image) + return_images.append(image) + else: + use_buffer = torch.rand(1) < self.buffer_ratio + # by self.buffer_ratio, the buffer will return a previously + # stored image, and insert the current image into the buffer + if use_buffer: + random_id = torch.randint(0, self.buffer_size, (1,)).item() + image_tmp = self.image_buffer[random_id].clone() + self.image_buffer[random_id] = image + return_images.append(image_tmp) + # by (1 - self.buffer_ratio), the buffer will return the + # current image + else: + return_images.append(image) + # collect all the images and return + return_images = torch.cat(return_images, 0) + return return_images