|
from torchvision import transforms |
|
from torchvision.transforms.functional import InterpolationMode |
|
|
|
from src.data.randaugment import RandomAugment |
|
|
|
# Per-channel statistics shared by the train and test pipelines below.
# NOTE(review): these values look like the RGB mean/std published with
# OpenAI CLIP -- confirm against the checkpoint this pipeline feeds.
_channel_mean = (0.48145466, 0.4578275, 0.40821073)
_channel_std = (0.26862954, 0.26130258, 0.27577711)
normalize = transforms.Normalize(mean=_channel_mean, std=_channel_std)
|
|
|
|
|
class transform_train:
    """Callable training-time image pipeline.

    Stages, in order: random resized crop (bicubic), random horizontal
    flip, ``RandomAugment``, conversion to a tensor, and normalization with
    the module-level ``normalize`` transform.

    Args:
        image_size: Output size handed to ``RandomResizedCrop``.
        min_scale: Lower bound of the crop ``scale`` range; the upper bound
            is fixed at 1.0.
    """

    def __init__(self, image_size=384, min_scale=0.5):
        crop = transforms.RandomResizedCrop(
            image_size,
            scale=(min_scale, 1.0),
            interpolation=InterpolationMode.BICUBIC,
        )
        flip = transforms.RandomHorizontalFlip()

        # Augmentation ops RandomAugment may pick from.
        aug_names = [
            "Identity",
            "AutoContrast",
            "Brightness",
            "Sharpness",
            "Equalize",
            "ShearX",
            "ShearY",
            "TranslateX",
            "TranslateY",
            "Rotate",
        ]
        # NOTE(review): the positional 2 and 5 are presumably the number of
        # ops per image and their magnitude -- confirm in RandomAugment.
        rand_aug = RandomAugment(2, 5, isPIL=True, augs=aug_names)

        stages = [crop, flip, rand_aug, transforms.ToTensor(), normalize]
        self.transform = transforms.Compose(stages)

    def __call__(self, img):
        """Run ``img`` through the composed training pipeline."""
        pipeline = self.transform
        return pipeline(img)
|
|
|
|
|
class transform_test(transforms.Compose):
    """Callable evaluation-time image pipeline.

    Stages, in order: resize to ``(image_size, image_size)`` with bicubic
    interpolation, conversion to a tensor, and normalization with the
    module-level ``normalize`` transform.

    Args:
        image_size: Side length of the square output passed to ``Resize``.
    """

    def __init__(self, image_size=384):
        steps = [
            transforms.Resize(
                (image_size, image_size),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            normalize,
        ]
        # Initialise the Compose base class so the inherited ``transforms``
        # attribute exists; without this, inherited behaviour such as
        # ``repr()`` fails with AttributeError.
        super().__init__(steps)
        # Kept for backward compatibility with callers that read
        # ``.transform`` directly.
        self.transform = transforms.Compose(steps)

    def __call__(self, img):
        """Run ``img`` through the composed evaluation pipeline."""
        return self.transform(img)
|
|