35 lines
1.4 KiB
Python
35 lines
1.4 KiB
Python
from torchvision import transforms
|
|
from torchvision.datasets.folder import default_loader
|
|
|
|
from .registry import TRANSFORM
|
|
|
|
# from https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html
|
|
_VALID_TORCHVISION_TRANSFORMS = [
    "ToTensor", "ToPILImage", "Normalize", "Resize",
    "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder",
    "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
    "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter",
    "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective",
    "RandomErasing",
]

# Register every listed torchvision transform class in the TRANSFORM registry.
# Some names in the list (e.g. the deprecated "Scale" / "RandomSizedCrop"
# aliases) are absent from newer torchvision releases. An unguarded getattr
# would raise AttributeError and make this whole module unimportable, so names
# the installed torchvision does not provide are skipped. A config that asks
# for one of them will then fail when it is built, not at import time.
for _transform_name in _VALID_TORCHVISION_TRANSFORMS:
    _transform_cls = getattr(transforms, _transform_name, None)
    if _transform_cls is not None:
        TRANSFORM.register_module(module=_transform_cls)
|
|
|
|
|
|
@TRANSFORM.register_module()
class Load:
    """Pipeline step that turns an image path into a loaded image.

    Args:
        loader: callable invoked with the path; its return value is what
            this step produces. Defaults to torchvision's ``default_loader``.
    """

    def __init__(self, loader=default_loader):
        self.loader = loader

    def __call__(self, image_path):
        """Return ``self.loader(image_path)``."""
        load_fn = self.loader
        return load_fn(image_path)

    def __repr__(self):
        return f"{type(self).__name__}()"
|
|
|
|
|
|
def _identity(x):
    """Return *x* unchanged (module-level so it can be pickled)."""
    return x


def transform_pipeline(pipeline_description):
    """Build a callable that applies the described transforms in order.

    Args:
        pipeline_description: iterable of transform configs, each passed to
            ``TRANSFORM.build_with``. The config schema is defined by the
            registry (presumably dicts naming a registered type; verify
            against the registry). ``None`` or an empty sequence means
            "no transforms".

    Returns:
        A ``transforms.Compose`` over the built transforms, or an identity
        function when there is nothing to build.
    """
    if not pipeline_description:
        # A named module-level function rather than a lambda: lambdas cannot
        # be pickled, which breaks multi-worker DataLoaders that use the
        # "spawn" start method when a dataset holds this pipeline.
        return _identity
    transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
    return transforms.Compose(transform_list)
|