google-research
107 строк · 3.4 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Analyze the experimental results from the logs."""
17
import glob

from absl import app
from absl import flags
import matplotlib.pyplot as plt
import numpy as np

FLAGS = flags.FLAGS

# Selects which experiment logs are analyzed; used to build the log-file glob.
flags.DEFINE_string(
    name="dataset_name",
    default="m5",
    help="Dataset to analyze the completed experiments for.",
)

# Experiments with too few scraped models are not visualized.
flags.DEFINE_integer(
    name="minimum_model_count",
    default=10,
    help="Minimum model count for an experiment to visualize.",
)

32
def scrape_data_from_logs(log_file):
    """Scrapes the validation and test metrics data from the logs.

    The log is scanned for lines starting with "Hyperparameters". For each
    such marker, the line right after it is kept verbatim as the
    hyperparameter description, and the line after that is split on commas:
    its second-to-last field is the validation metric and its last field
    (with any trailing "]" removed) is the test metric.

    A marker whose metrics line is missing or has fewer than two
    comma-separated fields (e.g. the log of a run that is still being
    written) is skipped instead of raising IndexError.

    Args:
      log_file: Path of the log file to read.

    Returns:
      A tuple (val_metrics, test_metrics, hyperparameters): two NumPy arrays
      of floats and a list of strings, all with one entry per complete record.

    Raises:
      ValueError: If a metric field of a complete record is not a valid float.
    """
    val_metrics = []
    test_metrics = []
    hyperparameters = []

    with open(log_file, "r") as myfile:
        lines = myfile.read().split("\n")

    for ind, line in enumerate(lines):
        if not line.startswith("Hyperparameters"):
            continue

        # The record needs two follow-up lines; a truncated log may lack them.
        if ind + 2 >= len(lines):
            continue

        line_comma_sep = lines[ind + 2].split(",")
        # An empty or partially written metrics line cannot hold both metrics.
        if len(line_comma_sep) < 2:
            continue

        val_metric = float(line_comma_sep[-2].strip(" "))
        test_metric = float(line_comma_sep[-1].strip(" ").strip("]"))

        val_metrics.append(val_metric)
        test_metrics.append(test_metric)
        hyperparameters.append(lines[ind + 1])

    val_metrics = np.asarray(val_metrics)
    test_metrics = np.asarray(test_metrics)

    return val_metrics, test_metrics, hyperparameters

56
def display_metrics(val_metrics, test_metrics, title, performance_threshold,
                    filename):
    """Displays the metrics of the trained models so far.

    Draws a scatter plot of the test metric against the validation metric
    (one point per model) and saves it to `filename`. Models whose validation
    metric is not below `performance_threshold` are treated as outliers and
    dropped. When more than two models remain, a dashed least-squares line is
    drawn through the points. Nothing is drawn or saved when no model remains.

    The figure is closed after saving so that repeated calls do not
    accumulate open figures in pyplot.

    Args:
      val_metrics: Validation metrics, one per model.
      test_metrics: Test metrics, aligned with `val_metrics`.
      title: Title of the plot.
      performance_threshold: Exclusive upper bound on the validation metric;
        models at or above it are left out.
      filename: Destination handed to `plt.savefig`.
    """
    val_metrics = np.asarray(val_metrics)
    test_metrics = np.asarray(test_metrics)

    # Remove the outliers, keeping both arrays aligned.
    keep = val_metrics < performance_threshold
    test_metrics = test_metrics[keep]
    val_metrics = val_metrics[keep]

    if val_metrics.size == 0:
        return

    fig = plt.figure()
    try:
        plt.plot(val_metrics, test_metrics, "o")
        if val_metrics.size > 2:
            # Linear trend of test metric as a function of validation metric.
            m, b = np.polyfit(val_metrics, test_metrics, 1)
            plt.plot(val_metrics, m * val_metrics + b, "k--")

        # Shared axis limits so both axes span the same range; the lower
        # bound gets 20% head-room, the upper bound none.
        v_min = np.min([np.min(val_metrics), np.min(test_metrics)]) * 0.8
        v_max = np.max([np.max(val_metrics), np.max(test_metrics)]) * 1.0

        plt.xlabel("Validation")
        plt.ylabel("Test")
        plt.title(title)
        plt.xlim([v_min, v_max])
        plt.ylim([v_min, v_max])
        plt.gca().set_aspect("equal", adjustable="box")
        plt.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)

83
def main(args):
    """Main function to iterate over the experiments.

    Finds every "./logs/experiment_<dataset_name>*.log" file, scrapes its
    validation/test metrics and, for experiments with more than
    `minimum_model_count` scraped models, prints the experiment name and saves
    a validation-vs-test scatter plot next to the log file (same path with the
    ".log" suffix replaced by ".png").

    Args:
      args: Positional command-line arguments passed to main; unused.
    """
    del args  # Not used.

    log_files = glob.glob("./logs/experiment_" + str(FLAGS.dataset_name) +
                          "*.log")

    # Sorted so that the printed output has a deterministic order.
    for log_file in sorted(log_files):
        experiment_name = log_file.split("/")[-1]
        experiment_name = experiment_name.split(".")[0]

        val_metrics, test_metrics, _ = scrape_data_from_logs(log_file)

        # NOTE(review): strict ">" means exactly `minimum_model_count` models
        # is not enough; confirm whether ">=" was intended.
        if len(val_metrics) > FLAGS.minimum_model_count:
            print("------------------------------------")
            print("Experiment name:")
            print(experiment_name)

            # One plot per experiment, titled with the experiment name. The
            # glob pattern guarantees the path ends with a 4-character ".log".
            plot_file = log_file[:-len(".log")] + ".png"
            display_metrics(val_metrics, test_metrics, experiment_name, 1000,
                            plot_file)

105
# Script entry point: only runs when the file is executed directly, handing
# control to absl's app.run, which invokes main.
if __name__ == "__main__":
    app.run(main)