# wandb

# Форк
# 0
# /
# dsviz_demo.py
# 515 строк · 18.2 Кб
1
import os
2
import pickle
3
import shutil
4
import time
5

6
import numpy as np
7
import wandb
8
from PIL import Image
9

10
# Keep a caller-provided project name; otherwise fall back to a throwaway,
# time-derived one. The chosen value is exported so wandb picks it up.
WANDB_PROJECT_ENV = os.environ.get("WANDB_PROJECT")
WANDB_PROJECT = (
    "test__" + str(round(time.time()) % 1000000)
    if WANDB_PROJECT_ENV is None
    else WANDB_PROJECT_ENV
)
os.environ["WANDB_PROJECT"] = WANDB_PROJECT

# Default wandb to silent mode unless the caller already chose a setting.
WANDB_SILENT_ENV = os.environ.get("WANDB_SILENT")
WANDB_SILENT = "true" if WANDB_SILENT_ENV is None else WANDB_SILENT_ENV
os.environ["WANDB_SILENT"] = WANDB_SILENT
23

24
# Number of raw examples logged in step 1 of the demo.
NUM_EXAMPLES = 10
# Small sample archive. The full dataset is
# "https://storage.googleapis.com/l2kzone/bdd100k.tgz" (folder "bdd100k").
LOCAL_FOLDER_NAME = "bdd20_small"
DL_URL = "https://raw.githubusercontent.com/wandb/dsviz-demo/master/bdd20_small.tgz"
LOCAL_ASSET_NAME = LOCAL_FOLDER_NAME + ".tgz"
28

29

30
# Class names, ordered so that position i corresponds to label id BDD_IDS[i].
BDD_CLASSES = [
    "road",
    "sidewalk",
    "building",
    "wall",
    "fence",
    "pole",
    "traffic light",
    "traffic sign",
    "vegetation",
    "terrain",
    "sky",
    "person",
    "rider",
    "car",
    "truck",
    "bus",
    "train",
    "motorcycle",
    "bicycle",
    "void",
]
# Label ids 0..18 belong to the first 19 classes; "void" pixels carry id 255.
BDD_IDS = [*range(len(BDD_CLASSES) - 1), 255]
# Reverse lookup: label id -> index into BDD_CLASSES / BDD_IDS.
BDD_ID_MAP = {class_id: position for position, class_id in enumerate(BDD_IDS)}
54

55
# Total number of label classes, "void" included.
n_classes = len(BDD_CLASSES)

# Layout of the extracted archive: ./<folder>/seg/{images,color_labels,labels}/train
bdd_dir = os.path.join(".", LOCAL_FOLDER_NAME, "seg")
train_dir, color_labels_dir, labels_dir = (
    os.path.join(bdd_dir, subdir, "train")
    for subdir in ("images", "color_labels", "labels")
)

# File stems of the training images; stays None until download_data() runs.
train_ids = None
62

63

64
def cleanup():
    """Delete every local file and folder a demo run may have produced.

    Removes, in order, the ``artifacts`` folder, the extracted dataset folder,
    the local ``wandb`` run folder, the downloaded archive and ``model.pkl``.
    Entries that do not exist are skipped.
    """
    for folder in ("artifacts", LOCAL_FOLDER_NAME, "wandb"):
        if os.path.isdir(folder):
            shutil.rmtree(folder)

    for filename in (LOCAL_ASSET_NAME, "model.pkl"):
        if os.path.isfile(filename):
            os.remove(filename)
79

80

81
def download_data():
    """Fetch and unpack the sample archive, then index the training images.

    The archive is downloaded only if LOCAL_ASSET_NAME is missing, and
    extracted only if LOCAL_FOLDER_NAME is missing, so repeated calls are
    cheap. On success the module-level ``train_ids`` holds the file stems
    found in ``train_dir``, sorted so that example selection and the
    train/test split are reproducible (``os.listdir`` order is arbitrary).

    Raises:
        subprocess.CalledProcessError: if ``curl`` or ``tar`` exits non-zero.
        FileNotFoundError: if ``curl``/``tar`` is not installed, or the
            extracted data has no ``train_dir``.
    """
    global train_ids
    import subprocess

    if not os.path.exists(LOCAL_ASSET_NAME):
        try:
            # -f: fail on HTTP errors instead of saving the error page as the
            # archive; -L: follow redirects.
            subprocess.run(
                ["curl", "-fL", DL_URL, "--output", LOCAL_ASSET_NAME], check=True
            )
        except BaseException:
            # Never leave a truncated archive behind: the existence check
            # above would skip the download on the next call.
            if os.path.isfile(LOCAL_ASSET_NAME):
                os.remove(LOCAL_ASSET_NAME)
            raise

    if not os.path.exists(LOCAL_FOLDER_NAME):
        try:
            subprocess.run(["tar", "xzf", LOCAL_ASSET_NAME], check=True)
        except BaseException:
            # A half-extracted folder would likewise be reused as-is next time.
            if os.path.isdir(LOCAL_FOLDER_NAME):
                shutil.rmtree(LOCAL_FOLDER_NAME)
            raise

    # Stem = text before the first "."; names starting with "." (hidden files)
    # give an empty stem and are skipped.
    stems = (name.split(".")[0] for name in os.listdir(train_dir))
    train_ids = sorted(stem for stem in stems if stem)
92

93

94
def _check_train_ids():
    """Ensure download_data() has populated the module-level ``train_ids``.

    Raises:
        RuntimeError: if ``train_ids`` is still None. RuntimeError subclasses
            Exception, so callers catching Exception keep working.
    """
    if train_ids is None:
        raise RuntimeError(
            "Please download the data using download_data() before attempting to access it."
        )
99

100

101
def _train_file_path(folder, ndx, suffix):
    """Join *folder* with the file name of training example *ndx* plus *suffix*."""
    _check_train_ids()
    file_name = train_ids[ndx] + suffix
    return os.path.join(folder, file_name)


def get_train_image_path(ndx):
    """Path of the ``.jpg`` training image for example *ndx*."""
    return _train_file_path(train_dir, ndx, ".jpg")


def get_color_label_image_path(ndx):
    """Path of the colour-coded label PNG for example *ndx*."""
    return _train_file_path(color_labels_dir, ndx, "_train_color.png")


def get_label_image_path(ndx):
    """Path of the class-id label PNG for example *ndx*."""
    return _train_file_path(labels_dir, ndx, "_train_id.png")
114

115

116
def get_dominant_id_ndx(np_image):
    """Return the BDD_CLASSES index of the most frequent label id in a mask.

    Accepts a numpy array or a ``wandb.Image`` (its wrapped image is converted
    to an array first). Values are cast to int before counting. Raises
    KeyError if the most frequent value is not one of BDD_IDS.
    """
    pixels = np.array(np_image.image) if isinstance(np_image, wandb.Image) else np_image
    counts = np.bincount(pixels.astype(int).ravel())
    most_common_id = counts.argmax()
    return BDD_ID_MAP[most_common_id]
120

121

122
def clean_artifacts_dir():
    """Remove the local ``artifacts`` folder if it exists; otherwise do nothing."""
    artifacts_dir = "artifacts"
    if not os.path.isdir(artifacts_dir):
        return
    shutil.rmtree(artifacts_dir)
125

126

127
def mask_to_bounding(np_image):
    """Derive one enclosing bounding box per class present in a label mask.

    Args:
        np_image: label mask of class ids, assumed 2-D (height, width), or a
            ``wandb.Image`` wrapping one (converted to an array first).

    Returns:
        A list of W&B box dicts, one per id in ``BDD_IDS`` that occurs in the
        mask, in BDD_IDS order::

            {"position": {"minX", "maxX", "minY", "maxY"}, "class_id": id}

        Coordinates are fractions of the mask width/height. ``max*`` is the
        far edge of the last covered pixel, so a class that fills the whole
        mask spans exactly 0..1. A class present in a single row or column
        still gets a one-pixel-thick box.
    """
    if isinstance(np_image, wandb.Image):
        np_image = np.array(np_image.image)

    height, width = np_image.shape[:2]
    data = []
    for id_num in BDD_IDS:
        matches = np_image == id_num
        # Columns / rows that contain at least one pixel of this class.
        cols = np.flatnonzero(matches.any(axis=0))
        rows = np.flatnonzero(matches.any(axis=1))
        if cols.size == 0:
            continue  # class absent from this mask

        data.append(
            {
                "position": {
                    "minX": float(cols[0] / width),
                    "maxX": float((cols[-1] + 1) / width),
                    "minY": float(rows[0] / height),
                    "maxY": float((rows[-1] + 1) / height),
                },
                "class_id": id_num,
            }
        )
    return data
155

156

157
def get_scaled_train_image(ndx, factor=2):
158
    return Image.open(get_train_image_path(ndx)).reduce(factor)
159

160

161
def get_scaled_mask_label(ndx, factor=2):
162
    return np.array(Image.open(get_label_image_path(ndx)).reduce(factor))
163

164

165
def get_scaled_bounding_boxes(ndx, factor=2):
166
    return mask_to_bounding(
167
        np.array(Image.open(get_label_image_path(ndx)).reduce(factor))
168
    )
169

170

171
def get_scaled_color_mask(ndx, factor=2):
172
    return Image.open(get_color_label_image_path(ndx)).reduce(factor)
173

174

175
def get_dominant_class(label_mask):
    """Return the name of the class that covers the most pixels of *label_mask*."""
    class_index = get_dominant_id_ndx(label_mask)
    return BDD_CLASSES[class_index]
177

178

179
class ExampleSegmentationModel:
    """Toy segmentation model that buckets pixel brightness into classes.

    ``train`` records the pixel-value range of the training images and
    ``n_classes`` evenly spaced quantiles of the normalised pixel values.
    ``predict`` averages each pixel's channels and assigns ``BDD_IDS[i]`` to
    pixels whose normalised brightness falls in
    ``(quantiles[i], quantiles[i + 1]]``. Pixels in no bucket stay 0.
    """

    def __init__(self, n_classes):
        # Number of brightness buckets. predict() indexes BDD_IDS with each
        # bucket, so this must not exceed len(BDD_IDS).
        self.n_classes = n_classes

    def _normalize(self, images):
        """Map pixel values into [0, 1] using the range seen during training."""
        span = self.max - self.min
        if span == 0:
            # Constant training data: avoid divide-by-zero (everything maps to 0).
            span = 1.0
        # Convert to float first so values below the training minimum go
        # negative instead of wrapping around in an unsigned integer dtype.
        return (np.asarray(images, dtype=float) - self.min) / span

    def train(self, images, masks):
        """Fit the brightness quantiles.

        Args:
            images: pixel array, expected shape (N, H, W, C).
            masks: label array whose shape (N, H, W) fixes the (H, W) of
                the masks returned by predict().
        """
        self.min = float(images.min())
        self.max = float(images.max())
        normalized = self._normalize(images)
        step = 1.0 / self.n_classes
        self.quantiles = list(
            np.quantile(normalized, [i * step for i in range(self.n_classes)])
        )
        self.quantiles.append(1.0)
        self.outshape = masks.shape

    def predict(self, images):
        """Return an (N, H, W) float array of predicted label ids for *images*."""
        brightness = self._normalize(images).mean(axis=3)
        results = np.zeros((images.shape[0], self.outshape[1], self.outshape[2]))
        for i in range(self.n_classes):
            in_bucket = (self.quantiles[i] < brightness) & (
                brightness <= self.quantiles[i + 1]
            )
            results[in_bucket] = BDD_IDS[i]
        return results

    def save(self, file_path):
        """Pickle the whole model to *file_path*."""
        with open(file_path, "wb") as file:
            pickle.dump(self, file)

    @staticmethod
    def load(file_path):
        """Unpickle a model written by save().

        WARNING: unpickling can execute arbitrary code; only load files from a
        trusted source.
        """
        with open(file_path, "rb") as file:
            return pickle.load(file)
213

214

215
def iou(mask_a, mask_b, class_id):
    """Per-example intersection-over-union of *class_id* between two mask stacks.

    Args:
        mask_a, mask_b: label arrays of shape (N, H, W).
        class_id: label id to compare.

    Returns:
        Float array of shape (N,). Examples where the class is absent from
        both masks (empty union) score 0 rather than NaN, without emitting a
        divide-by-zero warning.
    """
    in_a = mask_a == class_id
    in_b = mask_b == class_id
    intersection = (in_a & in_b).sum(axis=(1, 2))
    union = (in_a | in_b).sum(axis=(1, 2))
    return np.divide(
        intersection,
        union,
        out=np.zeros(union.shape, dtype=float),
        where=union > 0,
    )
223

224

225
def score_model(model, x_data, mask_data, n_classes):
    """Predict masks for *x_data* and score them per class with IoU.

    Returns:
        ``(scores, predictions)``: ``scores`` has one row per example and one
        column per id in BDD_IDS; ``predictions`` is ``model.predict(x_data)``.
        ``n_classes`` is accepted for interface compatibility but not used.
    """
    predictions = model.predict(x_data)
    per_class = [iou(predictions, mask_data, class_id) for class_id in BDD_IDS]
    scores = np.array(per_class).T
    return scores, predictions
228

229

230
def make_datasets(data_table, n_classes):
    """Stack a table's training images and label masks into numpy arrays.

    Column 1 of each row must hold the training image and column 3 the label
    mask, both as ``wandb.Image``. Every row is reshaped to the size of the
    first row's image.

    Returns:
        ``(train_data, mask_data)`` with shapes (N, H, W, 3) and (N, H, W).
        ``n_classes`` is accepted for interface compatibility but not used.
    """
    rows = data_table.data
    first_image = rows[0][1].image
    height, width = first_image.height, first_image.width

    images = [np.array(row[1].image).reshape(height, width, 3) for row in rows]
    masks = [np.array(row[3].image).reshape(height, width) for row in rows]
    return np.array(images), np.array(masks)
249

250

251
def _prediction_rows(source_table, scores, results, mask_key):
    """Build one results-table row per scored example (shared by steps 3 and 4).

    Each row is ``[id, image with predicted overlay, dominant predicted class,
    *per-class IoU scores]``. The overlay puts the predicted mask under
    *mask_key* on the original image from column 1 of *source_table*.
    """
    rows = []
    for ndx, class_scores in enumerate(scores):
        predicted = results[ndx]
        source_row = source_table.data[ndx]
        overlay = wandb.Image(
            source_row[1],
            masks={mask_key: {"mask_data": predicted}},
            # NOTE: these boxes are derived from the *prediction*; the
            # "ground_truth" key name is kept unchanged from the original demo.
            boxes={"ground_truth": {"box_data": mask_to_bounding(predicted)}},
        )
        rows.append(
            [source_row[0], overlay, BDD_CLASSES[get_dominant_id_ndx(predicted)]]
            + list(class_scores)
        )
    return rows


def _create_dataset():
    """Step 1: log the first examples as a table in the ``raw_data`` artifact."""
    with wandb.init(
        project=WANDB_PROJECT,  # The project to register this Run to
        job_type="create_dataset",  # Runs of the same type can be grouped together in the UI
        config={  # Tunable parameters recorded with the Run
            "num_examples": NUM_EXAMPLES,  # The number of raw samples to include.
            "scale_factor": 2,  # The scaling factor for the images
        },
    ) as run:
        # wandb.Classes attaches class names/ids to mask and box overlays.
        class_set = wandb.Classes(
            [{"name": name, "id": class_id} for name, class_id in zip(BDD_CLASSES, BDD_IDS)]
        )

        # One row per example.
        table = wandb.Table(
            columns=[
                "id",
                "train_image",
                "colored_image",
                "label_mask",
                "dominant_class",
            ]
        )

        for ndx in range(run.config["num_examples"]):
            scale = run.config.scale_factor
            # Read the label mask once. It feeds the mask overlay, the boxes
            # (one box enclosing each class) and the standalone label image.
            label_array = get_scaled_mask_label(ndx, scale)

            example = wandb.Image(
                get_scaled_train_image(ndx, scale),
                classes=class_set,
                masks={"ground_truth": {"mask_data": label_array}},
                boxes={"ground_truth": {"box_data": mask_to_bounding(label_array)}},
            )
            # Extra images that help during analysis; overlay metadata is optional.
            color_label = wandb.Image(get_scaled_color_mask(ndx, scale))
            label_mask = wandb.Image(label_array)

            table.add_data(
                train_ids[ndx],
                example,
                color_label,
                label_mask,
                get_dominant_class(label_mask),
            )

        # An Artifact is a versioned folder; the table is stored in it by name.
        artifact = wandb.Artifact(name="raw_data", type="dataset")
        artifact.add(table, "raw_examples")
        run.log_artifact(artifact)


def _split_dataset():
    """Step 2: split ``raw_data`` into ``train_data`` and ``test_data`` artifacts."""
    with wandb.init(
        project=WANDB_PROJECT,
        job_type="split_dataset",
        config={
            "train_pct": 0.7,
        },
    ) as run:
        # Artifacts are addressed as "<ARTIFACT_NAME>:<VERSION>". "latest"
        # follows new versions; an alias such as "raw_data:v0" pins one.
        dataset_artifact = run.use_artifact("raw_data:latest")
        # Tables are fetched by the name they were saved under in step 1.
        data_table = dataset_artifact.get("raw_examples")

        # Two artifacts let later runs download only the subset they need.
        train_count = int(len(data_table.data) * run.config.train_pct)
        train_table = wandb.Table(
            columns=data_table.columns, data=data_table.data[:train_count]
        )
        test_table = wandb.Table(
            columns=data_table.columns, data=data_table.data[train_count:]
        )

        train_artifact = wandb.Artifact("train_data", "dataset")
        test_artifact = wandb.Artifact("test_data", "dataset")
        train_artifact.add(train_table, "train_table")
        test_artifact.add(test_table, "test_table")

        run.log_artifact(train_artifact)
        run.log_artifact(test_artifact)


def _train_model():
    """Step 3: train on ``train_data``; log per-example IoU scores and the model."""
    with wandb.init(project=WANDB_PROJECT, job_type="model_train") as run:
        train_table = run.use_artifact("train_data:latest").get("train_table")

        train_data, mask_data = make_datasets(train_table, n_classes)
        model = ExampleSegmentationModel(n_classes)
        model.train(train_data, mask_data)

        # Each predicted mask is scored with one IoU value per class.
        scores, results = score_model(model, train_data, mask_data, n_classes)
        results_table = wandb.Table(
            columns=["id", "pred_mask", "dominant_pred"] + BDD_CLASSES,
            data=_prediction_rows(train_table, scores, results, "train_predicted_truth"),
        )
        results_artifact = wandb.Artifact("train_results", "dataset")
        results_artifact.add(results_table, "train_iou_score_table")
        run.log_artifact(results_artifact)

        # Save the model as a flat file in its own artifact.
        model.save("model.pkl")
        model_artifact = wandb.Artifact("trained_model", "model")
        model_artifact.add_file("model.pkl")
        run.log_artifact(model_artifact)


def _evaluate_model():
    """Step 4: score the saved model on ``test_data`` and log the results."""
    with wandb.init(project=WANDB_PROJECT, job_type="model_eval") as run:
        test_table = run.use_artifact("test_data:latest").get("test_table")
        test_data, mask_data = make_datasets(test_table, n_classes)

        # Download and load the model file logged in step 3.
        model_artifact = run.use_artifact("trained_model:latest")
        path = model_artifact.get_entry("model.pkl").download()
        model = ExampleSegmentationModel.load(path)
        scores, results = score_model(model, test_data, mask_data, n_classes)

        results_artifact = wandb.Artifact("test_results", "dataset")
        results_artifact.add(
            wandb.Table(
                ["id", "pred_mask_test", "dominant_pred_test"] + BDD_CLASSES,
                data=_prediction_rows(test_table, scores, results, "test_predicted_truth"),
            ),
            "test_iou_score_table",
        )
        run.log_artifact(results_artifact)


def _analyze_results():
    """Step 5: join train/test score tables with the raw data on ``id``."""
    with wandb.init(project=WANDB_PROJECT, job_type="model_result_analysis") as run:
        data_table = run.use_artifact("raw_data:latest").get("raw_examples")
        train_table = run.use_artifact("train_results:latest").get("train_iou_score_table")
        test_table = run.use_artifact("test_results:latest").get("test_iou_score_table")

        train_results = wandb.JoinedTable(train_table, data_table, "id")
        test_results = wandb.JoinedTable(test_table, data_table, "id")
        artifact = wandb.Artifact("summary_results", "dataset")
        artifact.add(train_results, "train_results")
        artifact.add(test_results, "test_results")
        run.log_artifact(artifact)


def _restore_environment():
    """Write back WANDB_PROJECT / WANDB_SILENT values that were set before import."""
    if WANDB_PROJECT_ENV is not None:
        os.environ["WANDB_PROJECT"] = WANDB_PROJECT_ENV
    if WANDB_SILENT_ENV is not None:
        os.environ["WANDB_SILENT"] = WANDB_SILENT_ENV


def main():
    """Run the five-step dataset -> split -> train -> evaluate -> analyse demo.

    Downloads the sample data, then executes five W&B runs and prints a
    progress line after each. The environment restore and cleanup() run in
    ``finally``, whether or not a step fails:
    - pre-existing WANDB_PROJECT / WANDB_SILENT values are written back;
    - all local demo files are deleted by cleanup().
    """
    try:
        download_data()

        _create_dataset()
        print("Step 1/5 Complete")

        _split_dataset()
        print("Step 2/5 Complete")

        _train_model()
        print("Step 3/5 Complete")

        _evaluate_model()
        print("Step 4/5 Complete")

        _analyze_results()
        print("Step 5/5 Complete")
    finally:
        _restore_environment()
        cleanup()


if __name__ == "__main__":
    main()
516

# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.