# wandb artifacts demo (515 lines · 18.2 KB)
# Standard library
import os
import pickle
import shutil
import time

# Third-party: numpy for mask math, wandb for experiment/artifact tracking,
# PIL for image loading/downscaling.
import numpy as np
import wandb
from PIL import Image
# Remember any pre-existing WANDB_PROJECT so main() can restore it afterwards;
# otherwise fall back to a quasi-unique throwaway project name for this run.
WANDB_PROJECT_ENV = os.environ.get("WANDB_PROJECT")
if WANDB_PROJECT_ENV is None:
    WANDB_PROJECT = "test__" + str(round(time.time()) % 1000000)
else:
    WANDB_PROJECT = WANDB_PROJECT_ENV
os.environ["WANDB_PROJECT"] = WANDB_PROJECT

# Same dance for WANDB_SILENT: default to silent wandb logging unless the
# caller already configured it.
WANDB_SILENT_ENV = os.environ.get("WANDB_SILENT")
if WANDB_SILENT_ENV is None:
    WANDB_SILENT = "true"
else:
    WANDB_SILENT = WANDB_SILENT_ENV
os.environ["WANDB_SILENT"] = WANDB_SILENT

# Number of raw examples to include in the dataset artifact.
NUM_EXAMPLES = 10
# Small subset of the BDD100K segmentation dataset (the commented URL is the
# full-size alternative).
DL_URL = "https://raw.githubusercontent.com/wandb/dsviz-demo/master/bdd20_small.tgz"  # "https://storage.googleapis.com/l2kzone/bdd100k.tgz"
LOCAL_FOLDER_NAME = "bdd20_small"  # "bdd100k"
LOCAL_ASSET_NAME = f"{LOCAL_FOLDER_NAME}.tgz"


# Semantic class names used by the BDD segmentation labels.
BDD_CLASSES = [
    "road",
    "sidewalk",
    "building",
    "wall",
    "fence",
    "pole",
    "traffic light",
    "traffic sign",
    "vegetation",
    "terrain",
    "sky",
    "person",
    "rider",
    "car",
    "truck",
    "bus",
    "train",
    "motorcycle",
    "bicycle",
    "void",
]
# Label ids: 0..18 for the real classes, 255 for "void".
BDD_IDS = list(range(len(BDD_CLASSES) - 1)) + [255]
# Reverse lookup: label id -> index into BDD_CLASSES.
BDD_ID_MAP = {id: ndx for ndx, id in enumerate(BDD_IDS)}

n_classes = len(BDD_CLASSES)
# Layout of the extracted archive: images plus two flavors of label masks.
bdd_dir = os.path.join(".", LOCAL_FOLDER_NAME, "seg")
train_dir = os.path.join(bdd_dir, "images", "train")
color_labels_dir = os.path.join(bdd_dir, "color_labels", "train")
labels_dir = os.path.join(bdd_dir, "labels", "train")

# Populated by download_data(); None until the dataset has been fetched.
train_ids = None
63
def cleanup():
    """Remove every local file and directory this demo creates."""
    for directory in ("artifacts", LOCAL_FOLDER_NAME, "wandb"):
        if os.path.isdir(directory):
            shutil.rmtree(directory)

    for file_path in (LOCAL_ASSET_NAME, "model.pkl"):
        if os.path.isfile(file_path):
            os.remove(file_path)
80
def download_data():
    """Fetch and unpack the demo archive, then cache the train example ids."""
    global train_ids

    # Skip the download / extraction steps when the results already exist.
    if not os.path.exists(LOCAL_ASSET_NAME):
        os.system(f"curl {DL_URL} --output {LOCAL_ASSET_NAME}")

    if not os.path.exists(LOCAL_FOLDER_NAME):
        os.system(f"tar xzf {LOCAL_ASSET_NAME}")

    # Example ids are the image filenames with their extension stripped.
    stems = (name.split(".")[0] for name in os.listdir(train_dir))
    train_ids = [stem for stem in stems if stem != ""]
93
94def _check_train_ids():95if train_ids is None:96raise Exception(97"Please download the data using download_data() before attempting to access it."98)99
100
101def get_train_image_path(ndx):102_check_train_ids()103return os.path.join(train_dir, train_ids[ndx] + ".jpg")104
105
106def get_color_label_image_path(ndx):107_check_train_ids()108return os.path.join(color_labels_dir, train_ids[ndx] + "_train_color.png")109
110
111def get_label_image_path(ndx):112_check_train_ids()113return os.path.join(labels_dir, train_ids[ndx] + "_train_id.png")114
115
116def get_dominant_id_ndx(np_image):117if isinstance(np_image, wandb.Image):118np_image = np.array(np_image.image)119return BDD_ID_MAP[np.argmax(np.bincount(np_image.astype(int).flatten()))]120
121
122def clean_artifacts_dir():123if os.path.isdir("artifacts"):124shutil.rmtree("artifacts")125
126
def mask_to_bounding(np_image):
    """Derive normalized bounding boxes from a segmentation mask.

    For each BDD label id present in the mask, compute the tightest
    axis-aligned box (coordinates normalized to [0, 1]) enclosing all pixels
    of that class. Returns a list of wandb box dicts:
    {"position": {"minX", "maxX", "minY", "maxY"}, "class_id": id}.
    """
    if isinstance(np_image, wandb.Image):
        np_image = np.array(np_image.image)

    height, width = np_image.shape[0], np_image.shape[1]
    data = []
    for id_num in BDD_IDS:
        matches = np_image == id_num
        # Column/row indices containing at least one matching pixel.
        col_count = np.where(matches.sum(axis=0))[0]
        row_count = np.where(matches.sum(axis=1))[0]

        # Only emit a box when the class spans more than one row AND column.
        # Appending inside this guard also prevents absent classes from
        # reusing stale (or undefined) coordinates from a previous class.
        if len(col_count) > 1 and len(row_count) > 1:
            min_x = col_count[0] / width
            max_x = col_count[-1] / width
            min_y = row_count[0] / height
            max_y = row_count[-1] / height

            data.append(
                {
                    "position": {
                        "minX": min_x,
                        "maxX": max_x,
                        "minY": min_y,
                        "maxY": max_y,
                    },
                    "class_id": id_num,
                }
            )
    return data
156
157def get_scaled_train_image(ndx, factor=2):158return Image.open(get_train_image_path(ndx)).reduce(factor)159
160
161def get_scaled_mask_label(ndx, factor=2):162return np.array(Image.open(get_label_image_path(ndx)).reduce(factor))163
164
165def get_scaled_bounding_boxes(ndx, factor=2):166return mask_to_bounding(167np.array(Image.open(get_label_image_path(ndx)).reduce(factor))168)169
170
171def get_scaled_color_mask(ndx, factor=2):172return Image.open(get_color_label_image_path(ndx)).reduce(factor)173
174
175def get_dominant_class(label_mask):176return BDD_CLASSES[get_dominant_id_ndx(label_mask)]177
178
class ExampleSegmentationModel:
    """Toy segmentation "model" used to demo the wandb artifact workflow.

    It learns intensity quantile thresholds from the training images and
    "predicts" a class for every pixel by bucketing its normalized mean
    intensity. Not a real model — just deterministic and fast.
    """

    def __init__(self, n_classes):
        # Number of quantile buckets (one per segmentation class).
        self.n_classes = n_classes

    def train(self, images, masks):
        """Fit normalization bounds and per-class quantile thresholds.

        images: numeric array of shape (N, H, W, 3).
        masks:  label array whose shape is reused as the prediction shape.
        """
        self.min = images.min()
        self.max = images.max()
        images = (images - self.min) / (self.max - self.min)
        # BUG FIX: this previously divided by the module-level global
        # `n_classes`, which only worked when it happened to equal
        # self.n_classes. Use the instance's own class count.
        step = 1.0 / self.n_classes
        self.quantiles = list(
            np.quantile(images, [i * step for i in range(self.n_classes)])
        )
        self.quantiles.append(1.0)
        self.outshape = masks.shape

    def predict(self, images):
        """Assign each pixel a BDD id based on its normalized mean intensity."""
        results = np.zeros((images.shape[0], self.outshape[1], self.outshape[2]))
        images = ((images - self.min) / (self.max - self.min)).mean(axis=3)
        for i in range(self.n_classes):
            # Pixels falling in the i-th quantile bucket get the i-th id.
            results[
                (self.quantiles[i] < images) & (images <= self.quantiles[i + 1])
            ] = BDD_IDS[i]
        return results

    def save(self, file_path):
        """Pickle this model to `file_path`."""
        with open(file_path, "wb") as file:
            pickle.dump(self, file)

    @staticmethod
    def load(file_path):
        """Load a pickled model. NOTE: pickle is unsafe on untrusted files."""
        with open(file_path, "rb") as file:
            model = pickle.load(file)
        return model
214
215def iou(mask_a, mask_b, class_id):216return np.nan_to_num(217((mask_a == class_id) & (mask_b == class_id)).sum(axis=(1, 2))218/ ((mask_a == class_id) | (mask_b == class_id)).sum(axis=(1, 2)),2190,2200,2210,222)223
224
225def score_model(model, x_data, mask_data, n_classes):226results = model.predict(x_data)227return np.array([iou(results, mask_data, i) for i in BDD_IDS]).T, results228
229
def make_datasets(data_table, n_classes):
    """Extract (images, masks) numpy arrays from a wandb dataset table.

    Column 1 holds the training wandb.Image and column 3 the label-mask
    wandb.Image; both are converted to numpy arrays of uniform shape.
    """
    n_samples = len(data_table.data)
    first_image = data_table.data[0][1].image
    height = first_image.height
    width = first_image.width

    train_data = np.array(
        [
            np.array(data_table.data[i][1].image).reshape(height, width, 3)
            for i in range(n_samples)
        ]
    )
    mask_data = np.array(
        [
            np.array(data_table.data[i][3].image).reshape(height, width)
            for i in range(n_samples)
        ]
    )
    return train_data, mask_data
250
def main():
    """Run the full five-step wandb artifact demo end to end.

    Steps: (1) build and log the raw dataset table, (2) split it into
    train/test artifacts, (3) train and score the toy model, (4) evaluate on
    the test split, (5) join score tables back to the raw data. Local files
    are cleaned up even if a step fails.
    """
    try:
        # Download the data if not already
        download_data()

        # Initialize the run
        with wandb.init(
            project=WANDB_PROJECT,  # The project to register this Run to
            job_type="create_dataset",  # The type of this Run. Runs of the same type can be grouped together in the UI
            config={  # Custom configuration parameters which you might want to tune or adjust for the Run
                "num_examples": NUM_EXAMPLES,  # The number of raw samples to include.
                "scale_factor": 2,  # The scaling factor for the images
            },
        ) as run:
            # Setup a WandB Classes object. This will give additional metadata for visuals
            class_set = wandb.Classes(
                [{"name": name, "id": id} for name, id in zip(BDD_CLASSES, BDD_IDS)]
            )

            # Setup a WandB Table object to hold our dataset
            table = wandb.Table(
                columns=[
                    "id",
                    "train_image",
                    "colored_image",
                    "label_mask",
                    "dominant_class",
                ]
            )

            # Fill up the table
            for ndx in range(run.config["num_examples"]):
                # First, we will build a wandb.Image to act as our raw example object
                # classes: the classes which map to masks and/or box metadata
                # masks: the mask metadata. In this case, we use a 2d array where each cell corresponds to the label (this comes directly from the dataset)
                # boxes: the bounding box metadata. For example sake, we create bounding boxes by looking at the mask data and creating boxes which fully enclose each class.
                # The data is an array of objects like:
                # "position": {
                #     "minX": minX,
                #     "maxX": maxX,
                #     "minY": minY,
                #     "maxY": maxY,
                # },
                # "class_id" : id_num,
                # }
                example = wandb.Image(
                    get_scaled_train_image(ndx, run.config.scale_factor),
                    classes=class_set,
                    masks={
                        "ground_truth": {
                            "mask_data": get_scaled_mask_label(
                                ndx, run.config.scale_factor
                            )
                        },
                    },
                    boxes={
                        "ground_truth": {
                            "box_data": get_scaled_bounding_boxes(
                                ndx, run.config.scale_factor
                            )
                        }
                    },
                )

                # Next, we create two additional images which may be helpful during analysis. Notice that the additional metadata is optional.
                color_label = wandb.Image(
                    get_scaled_color_mask(ndx, run.config.scale_factor)
                )
                label_mask = wandb.Image(
                    get_scaled_mask_label(ndx, run.config.scale_factor)
                )

                # Finally, we add a row of our newly constructed data.
                table.add_data(
                    train_ids[ndx],
                    example,
                    color_label,
                    label_mask,
                    get_dominant_class(label_mask),
                )

            # Create an Artifact (versioned folder)
            artifact = wandb.Artifact(name="raw_data", type="dataset")

            # add the table to the artifact
            artifact.add(table, "raw_examples")

            # Finally, log the artifact
            run.log_artifact(artifact)
            print("Step 1/5 Complete")

        # This step should look familiar by now:
        with wandb.init(
            project=WANDB_PROJECT,
            job_type="split_dataset",
            config={
                "train_pct": 0.7,
            },
        ) as run:
            # Get the latest version of the artifact. Notice the name alias follows this convention: "<ARTIFACT_NAME>:<VERSION>"
            # when version is set to "latest", then the latest version will always be used. However, you can pin to a version by
            # using an alias such as "raw_data:v0"
            dataset_artifact = run.use_artifact("raw_data:latest")

            # Next, we "get" the table by the same name that we saved it in the last run.
            data_table = dataset_artifact.get("raw_examples")

            # Now we can build two separate artifacts for later use. We will first split the raw table into two parts,
            # then create two different artifacts, each of which will hold our new tables. We create two artifacts so that
            # in future runs, we can selectively decide which subsets of data to download.

            # Create the tables
            train_count = int(len(data_table.data) * run.config.train_pct)
            train_table = wandb.Table(
                columns=data_table.columns, data=data_table.data[:train_count]
            )
            test_table = wandb.Table(
                columns=data_table.columns, data=data_table.data[train_count:]
            )

            # Create the artifacts
            train_artifact = wandb.Artifact("train_data", "dataset")
            test_artifact = wandb.Artifact("test_data", "dataset")

            # Save the tables to the artifacts
            train_artifact.add(train_table, "train_table")
            test_artifact.add(test_table, "test_table")

            # Log the artifacts out as outputs of the run
            run.log_artifact(train_artifact)
            run.log_artifact(test_artifact)
            print("Step 2/5 Complete")

        # Again, create a run.
        with wandb.init(project=WANDB_PROJECT, job_type="model_train") as run:
            # Similar to before, we will load in the artifact and asset we need. In this case, the training data
            train_artifact = run.use_artifact("train_data:latest")
            train_table = train_artifact.get("train_table")

            # Next, we split out the labels and train the model
            train_data, mask_data = make_datasets(train_table, n_classes)
            model = ExampleSegmentationModel(n_classes)
            model.train(train_data, mask_data)

            # Finally we score the model. Behind the scenes, we score each mask on its IOU score.
            scores, results = score_model(model, train_data, mask_data, n_classes)

            # Let's create a new table. Notice that we create many columns - an evaluation score for each class type.
            results_table = wandb.Table(
                columns=["id", "pred_mask", "dominant_pred"] + BDD_CLASSES,
                # Data construction is similar to before, but we now use the predicted masks and bound boxes.
                data=[
                    [
                        train_table.data[ndx][0],
                        wandb.Image(
                            train_table.data[ndx][1],
                            masks={
                                "train_predicted_truth": {
                                    "mask_data": results[ndx],
                                },
                            },
                            boxes={
                                "ground_truth": {
                                    "box_data": mask_to_bounding(results[ndx])
                                }
                            },
                        ),
                        BDD_CLASSES[get_dominant_id_ndx(results[ndx])],
                    ]
                    + list(row)
                    for ndx, row in enumerate(scores)
                ],
            )

            # We create an artifact, add the table, and log it as part of the run.
            results_artifact = wandb.Artifact("train_results", "dataset")
            results_artifact.add(results_table, "train_iou_score_table")
            run.log_artifact(results_artifact)

            # Finally, let's save the model as a flat file and add that to its own artifact.
            model.save("model.pkl")
            model_artifact = wandb.Artifact("trained_model", "model")
            model_artifact.add_file("model.pkl")
            run.log_artifact(model_artifact)
            print("Step 3/5 Complete")

        with wandb.init(project=WANDB_PROJECT, job_type="model_eval") as run:
            # Retrieve the test data
            test_artifact = run.use_artifact("test_data:latest")
            test_table = test_artifact.get("test_table")
            test_data, mask_data = make_datasets(test_table, n_classes)

            # Download the saved model file.
            model_artifact = run.use_artifact("trained_model:latest")
            path = model_artifact.get_entry("model.pkl").download()

            # Load the model from the file and score it
            model = ExampleSegmentationModel.load(path)
            scores, results = score_model(model, test_data, mask_data, n_classes)

            # Create a predicted score table similar to step 3.
            results_artifact = wandb.Artifact("test_results", "dataset")
            data = [
                [
                    test_table.data[ndx][0],
                    wandb.Image(
                        test_table.data[ndx][1],
                        masks={
                            "test_predicted_truth": {
                                "mask_data": results[ndx],
                            },
                        },
                        boxes={
                            "ground_truth": {"box_data": mask_to_bounding(results[ndx])}
                        },
                    ),
                    BDD_CLASSES[get_dominant_id_ndx(results[ndx])],
                ]
                + list(row)
                for ndx, row in enumerate(scores)
            ]

            # And log out the results.
            results_artifact.add(
                wandb.Table(
                    ["id", "pred_mask_test", "dominant_pred_test"] + BDD_CLASSES,
                    data=data,
                ),
                "test_iou_score_table",
            )
            run.log_artifact(results_artifact)
            print("Step 4/5 Complete")

        with wandb.init(project=WANDB_PROJECT, job_type="model_result_analysis") as run:
            # Retrieve the original raw dataset
            dataset_artifact = run.use_artifact("raw_data:latest")
            data_table = dataset_artifact.get("raw_examples")

            # Retrieve the train and test score tables
            train_artifact = run.use_artifact("train_results:latest")
            train_table = train_artifact.get("train_iou_score_table")

            test_artifact = run.use_artifact("test_results:latest")
            test_table = test_artifact.get("test_iou_score_table")

            # Join the tables on ID column and log them as outputs.
            train_results = wandb.JoinedTable(train_table, data_table, "id")
            test_results = wandb.JoinedTable(test_table, data_table, "id")
            artifact = wandb.Artifact("summary_results", "dataset")
            artifact.add(train_results, "train_results")
            artifact.add(test_results, "test_results")
            run.log_artifact(artifact)
            print("Step 5/5 Complete")

        # Restore any environment variables that were overridden at import time.
        if WANDB_PROJECT_ENV is not None:
            os.environ["WANDB_PROJECT"] = WANDB_PROJECT_ENV

        if WANDB_SILENT_ENV is not None:
            os.environ["WANDB_SILENT"] = WANDB_SILENT_ENV
    finally:
        cleanup()
513
514if __name__ == "__main__":515main()516