raycv/data/transform.py

35 lines
1.4 KiB
Python

from torchvision import transforms
from torchvision.datasets.folder import default_loader
from .registry import TRANSFORM
# Names of torchvision transform classes exposed through the TRANSFORM
# registry. List taken from
# https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html
# "Scale" and "RandomSizedCrop" are deprecated aliases that recent torchvision
# releases have removed; they stay listed so older installs still register them.
_VALID_TORCHVISION_TRANSFORMS = [
    "ToTensor", "ToPILImage", "Normalize", "Resize",
    "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder",
    "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
    "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter",
    "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective",
    "RandomErasing",
]

for vtt in _VALID_TORCHVISION_TRANSFORMS:
    _transform_cls = getattr(transforms, vtt, None)
    # Skip names the installed torchvision does not provide instead of letting
    # an AttributeError abort the import of this whole module.
    if _transform_cls is None:
        continue
    TRANSFORM.register_module(module=_transform_cls)
@TRANSFORM.register_module()
class Load:
    """Transform that maps an image path to whatever ``loader`` returns for it.

    Args:
        loader: callable invoked with the path; defaults to torchvision's
            ``default_loader``.
    """

    def __init__(self, loader=default_loader):
        self.loader = loader

    def __call__(self, image_path):
        """Load ``image_path`` with the configured loader and return the result."""
        load_fn = self.loader
        return load_fn(image_path)

    def __repr__(self):
        # The loader is intentionally not included in the repr.
        return f"{type(self).__name__}()"
def _identity(x):
    """Return ``x`` unchanged (module-level so it can be pickled, unlike a lambda)."""
    return x


def transform_pipeline(pipeline_description):
    """Build a single callable that applies a sequence of registered transforms.

    Args:
        pipeline_description: iterable of transform configs, each passed to
            ``TRANSFORM.build_with`` (presumably dicts naming a registered
            transform plus its arguments -- TODO confirm against the registry).
            ``None`` or an empty sequence means "no transforms".

    Returns:
        A ``transforms.Compose`` over the built transforms, or an identity
        function when there is nothing to build.
    """
    if not pipeline_description:
        # Return a named module-level function rather than a lambda: lambdas
        # cannot be pickled, which breaks datasets sent to DataLoader worker
        # processes started with the "spawn" method.
        return _identity
    transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
    return transforms.Compose(transform_list)