import inspect

import torchvision.utils


def _make_grid_range_kwargs(value_range):
    """Map a normalization range onto the keyword ``make_grid`` understands.

    torchvision renamed ``make_grid``'s ``range`` argument to ``value_range``.
    The old name was deprecated (warning even when passed as ``None``) and
    dropped in later releases. Passing ``range=`` unconditionally therefore
    breaks on new versions, while ``value_range=`` breaks on old ones.

    Args:
        value_range: Optional ``(min, max)`` pair, or ``None``.

    Returns:
        An empty dict when ``value_range`` is ``None`` (so every torchvision
        version falls back to its own default). Otherwise a single-entry dict
        keyed by whichever of ``value_range`` / ``range`` the installed
        ``make_grid`` declares.
    """
    if value_range is None:
        return {}
    params = inspect.signature(torchvision.utils.make_grid).parameters
    keyword = "value_range" if "value_range" in params else "range"
    return {keyword: value_range}


def make_2d_grid(tensors, padding=0, normalize=True, range=None,
                 scale_each=False, pad_value=0):
    """Tile several image batches into one 2D grid, one column per batch.

    Each element of ``tensors`` is laid out as a single column (``nrow=1``),
    with its images stacked top to bottom. The resulting columns are then
    placed side by side, left to right, in iteration order.

    Args:
        tensors: Iterable of image batches. Each one must be something
            ``torchvision.utils.make_grid`` accepts, e.g. a ``(B, C, H, W)``
            tensor or a list of ``(C, H, W)`` tensors.
        padding: Pixels of padding around the images inside each column.
        normalize: Forwarded to ``make_grid`` per column. Each batch is
            therefore normalized independently of the others.
        range: Optional ``(min, max)`` used for normalization. It is forwarded
            to ``make_grid`` as ``value_range`` or ``range``, depending on the
            installed torchvision. The name shadows the builtin but is kept
            for backward compatibility with existing callers.
        scale_each: Forwarded to ``make_grid`` per column.
        pad_value: Fill value for the padding inside each column.

    Returns:
        The combined grid tensor produced by the final ``make_grid`` call.

    NOTE(review): the columns are joined with ``padding=0``, so each column
    keeps its own border. The gap between neighbouring columns therefore
    appears to be ``2 * padding``, versus ``padding`` between rows. Confirm
    this is intended. The columns presumably also need equal heights (same
    batch size and image size) for the final call to stack them. Confirm
    with callers.
    """
    range_kwargs = _make_grid_range_kwargs(range)
    # First merge the images of each batch in the `y` direction: one column
    # per batch.
    grids = [
        torchvision.utils.make_grid(
            img_batch,
            padding=padding,
            nrow=1,
            normalize=normalize,
            scale_each=scale_each,
            pad_value=pad_value,
            **range_kwargs,
        )
        for img_batch in tensors
    ]
    # Then merge the columns in the `x` direction, adding no extra padding.
    # normalize is left at make_grid's default, so values are not rescaled
    # a second time.
    return torchvision.utils.make_grid(grids, padding=0, nrow=len(grids))