| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import PIL |
|
|
| from torchvision import datasets, transforms |
|
|
| from timm.data import create_transform |
| from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
|
|
|
def build_dataset(is_train, args):
    """Build an ``ImageFolder`` dataset for either the train or the val split.

    Args:
        is_train: When truthy, images are read from ``<data_path>/train`` and
            the training transform is attached; otherwise images come from
            ``<data_path>/val`` with the evaluation transform.
        args: Namespace-like config object. ``args.data_path`` is the dataset
            root; the whole object is also forwarded to ``build_transform``.

    Returns:
        The ``datasets.ImageFolder`` instance. It is printed once (for a quick
        summary in the logs) before being returned.
    """
    # Build the transform first, matching the original evaluation order.
    split_transform = build_transform(is_train, args)

    subdir = 'val'
    if is_train:
        subdir = 'train'
    split_root = os.path.join(args.data_path, subdir)

    image_folder = datasets.ImageFolder(split_root, transform=split_transform)
    print(image_folder)
    return image_folder
|
|
|
|
def build_transform(is_train, args):
    """Create the image transform pipeline for training or evaluation.

    Both paths normalize with ``IMAGENET_DEFAULT_MEAN`` / ``IMAGENET_DEFAULT_STD``.

    Training delegates to timm's ``create_transform`` with bicubic
    interpolation and augmentation settings taken from ``args``. Evaluation
    builds a deterministic pipeline: bicubic resize, center crop to
    ``args.input_size``, ``ToTensor`` and ``Normalize``.

    Args:
        is_train: Truthy for the training (augmenting) transform, falsy for
            the evaluation transform.
        args: Namespace-like config object. ``input_size`` is always read.
            The training path additionally reads ``color_jitter``, ``aa``
            (auto-augment policy), ``reprob``, ``remode`` and ``recount``
            (random-erasing probability / mode / count).

    Returns:
        The object returned by ``create_transform`` for training, or a
        ``transforms.Compose`` for evaluation.
    """
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD

    if is_train:
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )

    # Import the submodule explicitly: a bare ``import PIL`` does not load
    # ``PIL.Image``, so ``PIL.Image.BICUBIC`` only worked when some other
    # import (e.g. torchvision) had already loaded it as a side effect.
    from PIL import Image

    # Evaluation sizing. Inputs up to 224 px keep the conventional 224/256
    # crop ratio: resize to input_size / 0.875 (e.g. 224 -> 256), then
    # center-crop. Larger inputs use crop_pct = 1.0, so the resize target
    # equals input_size.
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    resize_size = int(args.input_size / crop_pct)

    return transforms.Compose([
        transforms.Resize(resize_size, interpolation=Image.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
|
|