# gpt-neox / eval.py (77 lines, 2.6 KB)
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation tasks - modified from https://github.com/EleutherAI/lm-evaluation-harness"""
import json
import os
import sys
from datetime import datetime
from pprint import pprint

# Append the parent of this file's directory to the import search path.
# This must run before the project imports below.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
)

from megatron.training import forward_step
from megatron.utils import setup_for_inference_or_eval, init_wandb
from megatron.logging import tb_wandb_log
from eval_tasks import run_eval_harness
33
def main(input_args=None, overwrite_values=None):
    """Run the evaluation harness and, on rank 0, log and save the results.

    The model and configuration come from ``setup_for_inference_or_eval``
    (called with ``use_cache=False``). ``run_eval_harness`` is then run over
    ``neox_args.eval_tasks`` with 10000 bootstrap iterations. On rank 0 every
    entry of ``results["results"]`` is sent to ``tb_wandb_log`` under an
    ``eval/`` tag, the full results are pretty-printed, and they are written
    to a timestamped JSON file in the current working directory.

    Args:
        input_args: Forwarded unchanged to ``setup_for_inference_or_eval``.
        overwrite_values: Forwarded unchanged to
            ``setup_for_inference_or_eval``.

    Returns:
        None.
    """
    model, neox_args = setup_for_inference_or_eval(
        use_cache=False,
        input_args=input_args,
        overwrite_values=overwrite_values,
    )
    results = run_eval_harness(
        model,
        forward_step,
        neox_args,
        eval_tasks=neox_args.eval_tasks,
        bootstrap_iters=10000,
    )

    # NOTE(review): the logging, printing and file dump below are all treated
    # as rank-0-only work -- confirm against the original indentation.
    if neox_args.rank != 0:
        return

    init_wandb(neox_args=neox_args)

    # Log to wandb: nested dicts become one "<outer>_<inner>" entry per inner
    # key; any other value is logged under its outer key.
    for outer_key, outer_value in results["results"].items():
        if isinstance(outer_value, dict):
            entries = [
                ("_".join([outer_key, inner_key]), inner_value)
                for inner_key, inner_value in outer_value.items()
            ]
        else:
            entries = [(outer_key, outer_value)]

        for tag, value in entries:
            tb_wandb_log(
                f"eval/{tag}",
                value,
                neox_args.iteration,
                use_wandb=neox_args.use_wandb,
            )

    pprint(results)

    timestamp = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
    results_path = f"eval_results_{timestamp}.json"
    prefix = neox_args.eval_results_prefix
    if prefix:
        results_path = f"{prefix}_{results_path}"

    with open(results_path, "w") as out_file:
        json.dump(results, out_file, indent=4)
75
# Script entry point: run the evaluation when this file is executed directly.
if __name__ == "__main__":
    main()