# A patch file to replace ImageNet with a dummy dataset.
# Use only for benchmarking purposes.

diff --git a/train.py b/train.py
index 6e3b058..8ddbcdd 100755
--- a/train.py
+++ b/train.py
@@ -61,6 +61,34 @@ except ImportError:
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
 
+
+class DummyImageDataset(torch.utils.data.Dataset):
+    """Dummy dataset with synthetic images."""
+    _IMAGE_HEIGHT = 224
+    _IMAGE_WIDTH = 224
+
+    def __init__(self, num_images, num_classes):
+        super().__init__()
+        from PIL import Image
+        imarray = np.random.rand(self._IMAGE_HEIGHT, self._IMAGE_WIDTH, 3) * 255
+        self.img = Image.fromarray(imarray.astype('uint8')).convert('RGB')
+        self.num_images = num_images
+        self.num_classes = num_classes
+        self.transform = None
+        self.target_transform = None
+
+    def __len__(self):
+        return self.num_images
+
+    def __getitem__(self, idx):
+        if self.transform is not None:
+            img = self.transform(self.img)
+        target = idx % self.num_classes
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return img, target
+
+
 # The first arg parser parses out only the --config argument, this argument is used to
 # load a yaml file containing key-values that override the defaults for the main parser below
 config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
@@ -71,8 +99,6 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 
 # Dataset / Model parameters
-parser.add_argument('data_dir', metavar='DIR',
-                    help='path to dataset')
 parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                     help='dataset type (default: ImageFolder/ImageTar if empty)')
 parser.add_argument('--train-split', metavar='NAME', default='train',
@@ -486,17 +512,8 @@ def main():
     _logger.info('Scheduled epochs: {}'.format(num_epochs))
 
     # create the train and eval datasets
-    dataset_train = create_dataset(
-        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
-        class_map=args.class_map,
-        download=args.dataset_download,
-        batch_size=args.batch_size,
-        repeats=args.epoch_repeats)
-    dataset_eval = create_dataset(
-        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
-        class_map=args.class_map,
-        download=args.dataset_download,
-        batch_size=args.batch_size)
+    dataset_train = DummyImageDataset(num_images=1231167, num_classes=1000)
+    dataset_eval = DummyImageDataset(num_images=50000, num_classes=1000)
 
     # setup mixup / cutmix