skypilot

Форк
0
/
dummy_dataset.patch 
71 строка · 2.8 Кб
1
# A patch file to replace ImageNet with a dummy dataset.
2
# Use only for benchmarking purposes.
3

4
diff --git a/train.py b/train.py
5
index 6e3b058..8ddbcdd 100755
6
--- a/train.py
7
+++ b/train.py
8
@@ -61,6 +61,34 @@ except ImportError:
9
 torch.backends.cudnn.benchmark = True
10
 _logger = logging.getLogger('train')
11
 
12
+
13
+class DummyImageDataset(torch.utils.data.Dataset):
14
+    """Dummy dataset with synthetic images."""
15
+    _IMAGE_HEIGHT = 3072
16
+    _IMAGE_WIDTH = 2304
17
+
18
+    def __init__(self, num_images, num_classes):
19
+        import numpy as np
20
+        from PIL import Image
21
+        imarray = np.random.rand(self._IMAGE_HEIGHT, self._IMAGE_WIDTH, 3) * 255
22
+        self.img = Image.fromarray(imarray.astype('uint8')).convert('RGB')
23
+        self.num_images = num_images
24
+        self.num_classes = num_classes
25
+        self.transform = None
26
+        self.target_transform = None
27
+
28
+    def __len__(self):
29
+        return self.num_images
30
+
31
+    def __getitem__(self, idx):
32
+        if self.transform is not None:
33
+            img = self.transform(self.img)
34
+        target = idx % self.num_classes
35
+        if self.target_transform is not None:
36
+            target = self.target_transform(target)
37
+        return img, target
38
+
39
+
40
 # The first arg parser parses out only the --config argument, this argument is used to
41
 # load a yaml file containing key-values that override the defaults for the main parser below
42
 config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
43
@@ -71,8 +99,6 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
44
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
45
 
46
 # Dataset parameters
47
-parser.add_argument('data_dir', metavar='DIR',
48
-                    help='path to dataset')
49
 parser.add_argument('--dataset', '-d', metavar='NAME', default='',
50
                     help='dataset type (default: ImageFolder/ImageTar if empty)')
51
 parser.add_argument('--train-split', metavar='NAME', default='train',
52
@@ -486,17 +512,8 @@ def main():
53
         _logger.info('Scheduled epochs: {}'.format(num_epochs))
54
 
55
     # create the train and eval datasets
56
-    dataset_train = create_dataset(
57
-        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
58
-        class_map=args.class_map,
59
-        download=args.dataset_download,
60
-        batch_size=args.batch_size,
61
-        repeats=args.epoch_repeats)
62
-    dataset_eval = create_dataset(
63
-        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
64
-        class_map=args.class_map,
65
-        download=args.dataset_download,
66
-        batch_size=args.batch_size)
67
+    dataset_train = DummyImageDataset(num_images=1231167, num_classes=1000)
68
+    dataset_eval = DummyImageDataset(num_images=50000, num_classes=1000)
69
 
70
     # setup mixup / cutmix
71
     collate_fn = None
72

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.