"""Registry wiring for image transforms.

Registers a set of torchvision transform classes plus a local ``Load``
transform into the project's ``TRANSFORM`` registry, and provides
``transform_pipeline`` to assemble registered transforms into one callable.
"""
from torchvision import transforms
from torchvision.datasets.folder import default_loader

from .registry import TRANSFORM

# Names of torchvision transform classes exposed through the TRANSFORM registry.
# List taken from
# https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html
# NOTE: not every name exists in every torchvision release. "Scale" and
# "RandomSizedCrop" are deprecated aliases that newer releases no longer ship,
# while names such as "RandomErasing" only exist in newer ones. Missing names
# are skipped below instead of failing the whole import.
_VALID_TORCHVISION_TRANSFORMS = [
    "ToTensor",
    "ToPILImage",
    "Normalize",
    "Resize",
    "Scale",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomApply",
    "RandomChoice",
    "RandomOrder",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "RandomSizedCrop",
    "FiveCrop",
    "TenCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "RandomAffine",
    "Grayscale",
    "RandomGrayscale",
    "RandomPerspective",
    "RandomErasing",
]

for vtt in _VALID_TORCHVISION_TRANSFORMS:
    _transform_cls = getattr(transforms, vtt, None)
    if _transform_cls is None:
        # Not provided by the installed torchvision version. A plain
        # getattr() would raise AttributeError at import time and make
        # every other transform unusable, so skip just this one.
        continue
    TRANSFORM.register_module(module=_transform_cls)


@TRANSFORM.register_module()
class Load:
    """Transform that loads an image from a path.

    Args:
        loader: callable invoked with the image path; its return value is
            the result of the transform. Defaults to torchvision's
            ``default_loader``.
    """

    def __init__(self, loader=default_loader):
        self.loader = loader

    def __call__(self, image_path):
        """Return ``self.loader(image_path)``."""
        return self.loader(image_path)

    def __repr__(self):
        return self.__class__.__name__ + "()"


def transform_pipeline(pipeline_description):
    """Build one callable from a sequence of transform descriptions.

    Args:
        pipeline_description: iterable of transform descriptions. Each entry
            is passed to ``TRANSFORM.build_with``. Presumably these are
            config entries understood by the registry; confirm the expected
            shape in ``.registry``. ``None`` or an empty sequence means
            "no transforms".

    Returns:
        An identity function when there is nothing to build. Otherwise a
        ``transforms.Compose`` that applies the built transforms in order.
    """
    # Truthiness check: covers an empty list/tuple and also ``None``, which
    # ``len()`` would reject with a TypeError.
    if not pipeline_description:
        return lambda x: x
    transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
    return transforms.Compose(transform_list)