eval_math_gpt.py

"""

Example:

CUDA_VISIBLE_DEVICES=6 python3 eval_math_gpt.py \
    --arch=gpt2 \
    --math-dataroot=./MATH/test/*/*.json \
    --load=/data/sauravkadavath/maths-beta__modeling__checkpoints/MATH__bbox_only_3_epochs__finetune_6_epochs__pretraining_khan_latex_loss_only__gpt117/checkpoint.pth

"""

import io
import logging
import math
import os
import pprint
import sys
import json
import time
import transformers
import numpy as np

from tqdm import tqdm

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

from dataset.MATH import MATHDataset
from dataset.khan_academy import KhanAcademyMathDataset
from dataset.util import clean_numbers, last_boxed_only, last_boxed_only_string
from math_equivalence import is_equiv

def get_level_type(fname):
    """
    Somewhat inefficient, but much easier than changing the dataloader, and probably fine for evaluation
    """
    with open(fname, 'r') as fp:
        try:
            problem_data = json.load(fp)
        except Exception as e:
            print(f"Error loading JSON from {fname}", e)
            raise e
    level, prob_type = problem_data['level'], problem_data['type']
    try:
        level = int(level.split("Level ")[1])
    except:
        level = None
    return level, prob_type

def remove_boxed(s):
    left = "\\boxed{"
    try:
        assert s[:len(left)] == left
        assert s[-1] == "}"
        return s[len(left):-1]
    except:
        return None


def dict_to_gpu(d, device_id=None):
    new_dict = dict()
    for key, value in d.items():
        # Only move to GPU if cuda() is a function
        if 'cuda' in dir(value):
            new_dict[key] = value.cuda(device_id)
        else:
            new_dict[key] = value
    return new_dict


def get_real_sol_idxs(tokens_sol, tokenizer):
    """
    Return the start and stop indexes (inclusive) for everything inside \\boxed{...}
    """
    left_idx, right_idx = None, None
    for i in range(tokens_sol.shape[1]):
        if i < 3:
            continue

        if tokens_sol[0, i].item() and \
            tokens_sol[0, i-1].item() == 276 and \
            tokens_sol[0, i-2].item() == 3524:
            # at index i, we have the { of \\boxed{
            left_idx = i + 1 # Don't include the {

        if tokens_sol[0, i].item() == 50256:
            right_idx = i-2 # don't include the one token before the current one as well (usually the } from \boxed{})

    # Will error if either is not found, which we don't expect
    return left_idx, right_idx


def run_eval(args):

    argsdict = vars(args)
    print(pprint.pformat(argsdict))

    if args.tokenizer_merges_file is not None:
        tokenizer = transformers.GPT2Tokenizer.from_pretrained(args.arch, merges_file=args.tokenizer_merges_file)
    else:
        tokenizer = transformers.GPT2Tokenizer.from_pretrained(args.arch)

    eval_data = get_dataset(args)
    for inner_dset in eval_data.datasets:
        inner_dset.tokenizer = tokenizer

    dataloader = torch.utils.data.DataLoader(
        eval_data,
        batch_size=1,
        num_workers=0,
        pin_memory=True,
    )

    """
    with torch.no_grad():
        correct = 0
        total = 0
        for i, batch in enumerate(tqdm(dataloader)):
            batch = dict_to_gpu(batch, device_id=0)
            print(batch['fnames'])
            print(batch['input_ids'])
            quit()
    """

    # Set up model
    if args.load is None:
        model = transformers.GPT2LMHeadModel.from_pretrained(args.arch)
    else:
        print(f"Loading model from {args.load}")
        model = transformers.GPT2LMHeadModel.from_pretrained(args.load)
        print(f"Successfully loaded model from {args.load}")

    model = model.eval()
    model = model.cuda()

    loss_moving_average = 0

    outputs = []
    answers = []
    types = []
    levels = []
    fnames_list = []

    cors = {}
    subject_cors = {}
    level_cors = {}

    with torch.no_grad():
        correct = 0
        total = 0
        skipped = 0
        mean_max_probs_correct = []
        mean_max_probs_wrong   = []
        for i, batch in enumerate(tqdm(dataloader)):

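            # An all-zero input_ids tensor is presumably the dataset's sentinel for a problem
            # that could not be encoded, so count it as skipped rather than evaluating it.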
            if torch.sum(batch['input_ids']) == 0:
                skipped += 1
                print("SKIPPING", batch['fnames'][0])
                continue

            fnames = batch['fnames'][0]
            assert len(fnames) == 1
            fnames_list.append(fnames[0])
            prob_level, prob_type = get_level_type(fnames[0])
            batch = dict_to_gpu(batch, device_id=0)

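            # Decode with beam search (--num-beams, default 20). The shorter max_length for
            # gpt2-xl is presumably there to keep beam search within GPU memory for the
            # larger model.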
            output_ids = model.generate(
                batch['input_ids'],
                num_beams=args.num_beams,
                early_stopping=True,
                temperature=1.0,
                max_length=384 if args.arch == 'gpt2-xl' else 1024
            )

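            # The block below (left commented out) scored the generated solution by the mean
            # of the per-token max softmax probabilities inside \boxed{...}. With it disabled,
            # mean_probs_sol stays 0, so the mean_max_probs averages printed later carry no
            # information.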
            # logits = model(output_ids).logits
            # probs = F.softmax(logits, dim=2) # torch.Size([1, L, 50257])
            # max_probs, max_tokens = probs.max(2) # torch.Size([1, L]), torch.Size([1, L])

            # num_tokens_for_question = batch['input_ids'].shape[1]
            # probs_sol = max_probs[:, num_tokens_for_question-1:]
            # tokens_sol = max_tokens[:, num_tokens_for_question-1:]

            # real_sol_start_idx, real_sol_stop_idx = get_real_sol_idxs(tokens_sol, tokenizer)
            # if real_sol_start_idx is None or real_sol_stop_idx is None:
            #     skipped += 1
            #     print("BAD ANSWER, SKIPPING", batch['fnames'][0])
            #     continue
            # probs_sol = probs_sol[:, real_sol_start_idx:real_sol_stop_idx + 1]
            # mean_probs_sol = torch.mean(probs_sol).item()
            mean_probs_sol = 0

            output_tokens = get_model_output(batch['input_ids'][0], output_ids[0], tokenizer)

            # Print this iteration
            output_str = tokenizer.decode(output_tokens)
            output_full = output_str
            output_str = last_boxed_only_string(output_str)

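            # In "eval_peeking" mode the labels presumably contain (part of) the ground-truth
            # solution, so only its final \boxed{...} expression is kept; otherwise the label
            # is already just the boxed final answer.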
            if args.math_mode == "eval_peeking":
                answer_str = last_boxed_only_string(tokenizer.decode(batch['labels'][0]))
            else:
                answer_str = tokenizer.decode(batch['labels'][0])

            output, answer = remove_boxed(output_str), remove_boxed(answer_str)

            print("Problem String:")
            print(tokenizer.decode(batch['input_ids'][0]) + "\n")
            print("Model output:")
            print(output_full)
            print(output)
            print("Correct answer:")
            print(answer)
            print("fname")
            print(fnames)
            print("--------------------------------------------")

            # scratchwork_fname = "___".join(fnames[0].split("/")[-2:])
            # with open(f"scratchwork_Temp2e-1_{args.arch}/{scratchwork_fname}.txt", 'w') as f:
            #     f.write("Problem String:" + "\n")
            #     f.write(tokenizer.decode(batch['input_ids'][0]) + "\n")
            #     f.write("Model output:" + "\n")
            #     f.write(output_full + "\n")
            #     f.write(str(output) + "\n")
            #     f.write("Correct answer:" + "\n")
            #     f.write(answer + "\n")
            #     f.write("--------------------------------------------" + "\n")

            outputs.append(output)
            answers.append(answer)
            types.append(prob_type)
            levels.append(prob_level)

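            # is_equiv (from math_equivalence) judges whether the two answer strings match
            # after normalizing formatting, not just by literal string equality. Record the
            # result under (level, type), level alone, and subject alone for the breakdowns below.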
            equiv = is_equiv(output, answer)
            if (prob_level, prob_type) in cors:
                cors[(prob_level, prob_type)].append(equiv)
            else:
                cors[(prob_level, prob_type)] = [equiv]

            if prob_level in level_cors:
                level_cors[prob_level].append(equiv)
            else:
                if prob_level is not None:
                    level_cors[prob_level] = [equiv]

            if prob_type in subject_cors:
                subject_cors[prob_type].append(equiv)
            else:
                if prob_type is not None:
                    subject_cors[prob_type] = [equiv]

            if equiv:
                correct += 1
                mean_max_probs_correct.append(mean_probs_sol)
            else:
                mean_max_probs_wrong.append(mean_probs_sol)

            # print("CORRECT", mean_max_probs_correct)
            # print("WRONG", mean_max_probs_wrong)

            total += 1

    subjects = ['Prealgebra', 'Algebra', 'Number Theory', 'Counting & Probability', 'Geometry', 'Intermediate Algebra', 'Precalculus']

    print(f"Average of mean_max_probs_correct = {sum(mean_max_probs_correct)}/{len(mean_max_probs_correct)} = ", sum(mean_max_probs_correct)/len(mean_max_probs_correct))
    print(f"Average of mean_max_probs_wrong   = {sum(mean_max_probs_wrong)}/{len(mean_max_probs_wrong)} = ", sum(mean_max_probs_wrong)/len(mean_max_probs_wrong))

    # now save outputs and answers
    with open(f"outputs_answers_Temp2e-1_{args.arch}.txt", "w+") as f:
        for k, (output, answer, prob_type, prob_level, fname) in enumerate(zip(outputs, answers, types, levels, fnames_list)):
            f.write("{} TYPE: {} | LEVEL: {} | OUTPUT: {} | ANSWER: {} | FNAME: {}\n".format(k, prob_type, prob_level, output, answer, fname))

        # print(cors)
        for prob_type in subjects:
            for prob_level in [1, 2, 3, 4, 5]:
                if (prob_level, prob_type) in cors:
                    cors_list = cors[(prob_level, prob_type)]
                    print("{} Level {} Accuracy = {}/{} = {:.3f}".format(prob_type, prob_level, np.sum(cors_list), len(cors_list), np.mean(cors_list)))
                    f.write("{} Level {} Accuracy = {}/{} = {:.3f}\n".format(prob_type, prob_level, np.sum(cors_list), len(cors_list), np.mean(cors_list)))

        print("#####################")
        f.write("#####################\n")
        # also get accuracies for each difficulty level
        for level in sorted(level_cors):
            cors_list = level_cors[level]
            print("Level {} Accuracy = {}/{} = {:.3f}".format(level, np.sum(cors_list), len(cors_list), np.mean(cors_list)))
            f.write("Level {} Accuracy = {}/{} = {:.3f}\n".format(level, np.sum(cors_list), len(cors_list), np.mean(cors_list)))
        print("#####################")
        f.write("#####################\n")

        for subject in subjects:
            # for subject in sorted(subject_cors):
            if subject in subject_cors:
                cors_list = subject_cors[subject]
                print("{} Accuracy = {}/{} = {:.3f}".format(subject, np.sum(cors_list), len(cors_list), np.mean(cors_list)))
                f.write("{} Accuracy = {}/{} = {:.3f}\n".format(subject, np.sum(cors_list), len(cors_list), np.mean(cors_list)))
        print("#####################")
        f.write("#####################\n")

        print("Overall Accuracy = {}/{} = {:.3f}".format(correct, total, correct/total))
        print("Skipped = {}".format(skipped))
        f.write("Overall Accuracy = {}/{} = {:.3f}\n".format(correct, total, correct/total))
        f.write("Skipped = {}".format(skipped))

    print()

def get_model_output(context, full_output, tokenizer):
    """
    Given the context and the full model output (context + generated),
    extract just the generated tokens.
    Remove the last token if it is <|endoftext|>
    """
    ret = full_output[len(context):]
    if ret[-1] == tokenizer.eos_token_id:
        ret = ret[:-1]
    return ret

def get_dataset(args):
    all_datasets = []

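    # Both branches build an eval-mode MATHDataset; the second additionally forwards the
    # answer mode (e.g. "eval_peeking") and --peek-fraction, presumably so the dataset can
    # reveal part of the ground-truth solution in the prompt.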
    if args.math_dataroot is not None:
        if args.math_mode == 'gpt2-eval':
            all_datasets.append(
                MATHDataset(
                    dataroot=args.math_dataroot,
                    tokenizer=None, # Set in run_eval(), not in dataset creation
                    max_tokens=384 if args.arch == 'gpt2-xl' else 1024,
                    mode='gpt2-eval',
                )
            )
        else:
            all_datasets.append(
                MATHDataset(
                    dataroot=args.math_dataroot,
                    tokenizer=None, # Set in run_eval(), not in dataset creation
                    max_tokens=384 if args.arch == 'gpt2-xl' else 1024,
                    mode='gpt2-eval',
                    mode_answer=args.math_mode,
                    peek_fraction=args.peek_fraction
                )
            )


    train_data = torch.utils.data.ConcatDataset(all_datasets)
    return train_data


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Language Modelling on Code")
    parser.add_argument('--arch', default='gpt2', choices=transformers.GPT2_PRETRAINED_MODEL_ARCHIVE_LIST)
    parser.add_argument('--load', default=None, type=str)
    parser.add_argument('--num-beams', default=20, type=int)
    parser.add_argument('--tokenizer-merges-file', default=None, type=str)

    # Dataloading
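    # --math-dataroot is a glob over per-problem JSON files (e.g. ./MATH/test/*/*.json,
    # as in the usage example at the top of this file).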
    parser.add_argument('--math-dataroot', default=None, type=str)
    parser.add_argument('--math-mode', default='gpt2-eval', type=str)
    parser.add_argument('--peek-fraction', type=float, default=1.0)

    # Others
    parser.add_argument('--workers', default=4, type=int)

    args = parser.parse_args()

    run_eval(args)