Prompt-Transferability

Форк
0
/
copy_prompt_to_taskpromptdir.py 
391 строка · 12.1 Кб
# Copy each selected prompt model's task-prompt checkpoint into
# ``task_prompt_emb/<name>/task_prompt``.
#
# For every directory ``model/<name>`` whose name contains MODEL_FILTER and
# that holds at least one file with "_task_prompt" in its name, the checkpoint
# ``<epoch>_task_prompt.pkl`` for the epoch listed in MAX_EPOCH is loaded, its
# ``"model"`` entry is extracted, and that object is written with torch.save.
import os
import shutil  # NOTE(review): not used by this script; kept to avoid touching file-level imports
import torch

# Only prompt directories whose name contains this substring are processed.
# (The original script switched between values such as "T5", "RobertaLarge"
# and "Small" by editing commented-out lines; pass ``model_filter`` instead.)
MODEL_FILTER = "T5Large"

# Training-status notes kept from the original author ("Haven't done"):
#   tweet (training, 68.XX), ethicsdeontologyPromptT5 (63.8),
#   ethicsjusticePromptT5 (60.XX), QQP (86.6), squadPromptT5 (62.7),
#   nq_openPromptT5, multi_newsPromptT5, samsumPromptT5,
#   MNLI (training), snli (training)

# Epoch whose ``<epoch>_task_prompt.pkl`` is copied for each prompt directory.
# Directories not listed here are reported and skipped.
# Trailing markers are preserved from the original hand-written table. The
# author's legend says "##" means "do not use". The meaning of "#" / "###" and
# of the alternative epoch numbers in comments (e.g. "# 45") is not documented.
MAX_EPOCH = {
    # IMDB
    "IMDBPromptRoberta": 23,
    "IMDBPromptRobertaSmall": 29,
    "IMDBPromptRobertaLarge": 27,
    "IMDBPromptRoberta_label": 40,
    "IMDBPromptBert": 21,
    "IMDBPromptT5": 70,
    "IMDBPromptT5Small": 69,  # 45
    "IMDBPromptT5Large": 60,
    # SST-2
    "SST2PromptRoberta": 25,
    "SST2PromptRobertaSmall": 38,
    "SST2PromptRobertaLarge": 26,
    "SST2PromptRoberta_label": 18,
    "SST2PromptBert": 18,
    "SST2PromptT5": 26,
    "SST2PromptT5Small": 17,  # 9
    "SST2PromptT5Large": 36,
    # laptop
    "laptopPromptRoberta": 32,
    "laptopPromptRobertaSmall": 40,
    "laptopPromptRobertaLarge": 93,
    "laptopPromptRoberta_label": 32,
    "laptopPromptBert": 30,
    "laptopPromptT5": 210,
    "laptopPromptT5Small": 229,  # 92
    "laptopPromptT5Large": 499,
    # restaurant
    "restaurantPromptRoberta": 33,
    "restaurantPromptRobertaSmall": 50,
    "restaurantPromptRobertaLarge": 126,
    "restaurantPromptRoberta_label": 32,
    "restaurantPromptBert": 31,
    "restaurantPromptT5": 276,
    "restaurantPromptT5Small": 224,  # 162
    "restaurantPromptT5Large": 100,
    # movie rationales
    "movierationalesPromptRoberta": 21,
    "movierationalesPromptRobertaSmall": 31,
    "movierationalesPromptRobertaLarge": 62,
    "movierationalesPromptRoberta_label": 48,
    "movierationalesPromptBert": 24,
    "movierationalesPromptT5": 197,
    "movierationalesPromptT5Small": 373,  # 299
    "movierationalesPromptT5Large": 100,
    # tweeteval sentiment
    "tweetevalsentimentPromptRoberta": 28,
    "tweetevalsentimentPromptRobertaSmall": 37,
    "tweetevalsentimentPromptRobertaLarge": 54,
    "tweetevalsentimentPromptRoberta_label": 23,
    "tweetevalsentimentPromptBert": 21,
    "tweetevalsentimentPromptT5": 18,
    "tweetevalsentimentPromptT5Small": 32,  # 20 better
    "tweetevalsentimentPromptT5Large": 38,
    # MNLI
    "MNLIPromptRoberta": 44,
    "MNLIPromptRobertaSmall": 13,
    "MNLIPromptRobertaLarge": 10,  ###
    "MNLIPromptRoberta_label": 30,
    "MNLIPromptBert": 34,
    "MNLIPromptT5": 5,
    "MNLIPromptT5Small": 4,  ##
    # QNLI
    "QNLIPromptRoberta": 51,
    "QNLIPromptRobertaSmall": 48,
    "QNLIPromptRobertaLarge": 33,  ###
    "QNLIPromptRoberta_label": 67,
    "QNLIPromptBert": 41,
    "QNLIPromptT5": 30,
    "QNLIPromptT5Small": 11,  ##
    # WNLI
    "WNLIPromptRoberta": 755,
    "WNLIPromptRoberta_label": 755,
    "WNLIPromptBert": 754,
    # SNLI
    "snliPromptRoberta": 29,
    "snliPromptRobertaSmall": 9,
    "snliPromptRobertaLarge": 4,  ###
    "snliPromptRoberta_label": 17,
    "snliPromptBert": 32,
    "snliPromptT5": 2,
    "snliPromptT5Small": 2,  #
    # RTE
    "RTEPromptRoberta": 250,
    "RTEPromptRoberta_label": 250,
    "RTEPromptBert": 249,
    # QQP
    "QQPPromptRoberta": 22,
    "QQPPromptRobertaSmall": 17,
    "QQPPromptRobertaLarge": 8,  ###
    "QQPPromptRoberta_label": 26,
    "QQPPromptBert": 24,
    "QQPPromptT5": 10,
    "QQPPromptT5Small": 2,  ###
    # MRPC
    "MRPCPromptRoberta": 66,
    "MRPCPromptRobertaSmall": 53,
    "MRPCPromptRobertaLarge": 96,
    "MRPCPromptRoberta_label": 30,
    "MRPCPromptBert": 27,
    "MRPCPromptT5": 199,
    "MRPCPromptT5Small": 213,
    # recast
    "recastfactualityPromptRoberta": 21,
    "recastfactualityPromptRoberta_label": 21,
    "recastfactualityPromptBert": 20,
    "recastpunsPromptRoberta": 36,
    "recastpunsPromptRoberta_label": 36,
    "recastpunsPromptBert": 35,
    "recastverbcornerPromptRoberta": 35,
    "recastverbcornerPromptRoberta_label": 35,
    "recastverbcornerPromptBert": 34,
    "recastnerPromptRoberta": 30,
    "recastnerPromptRoberta_label": 18,
    "recastnerPromptBert": 20,
    "recastsentimentPromptRoberta": 58,
    "recastsentimentPromptRoberta_label": 58,
    "recastsentimentPromptBert": 57,
    "recastmegaveridicalityPromptRoberta": 32,
    "recastmegaveridicalityPromptRoberta_label": 32,
    "recastmegaveridicalityPromptBert": 31,
    # ethics
    "ethicscommonsensePromptRoberta": 96,
    "ethicscommonsensePromptRoberta_label": 96,
    "ethicscommonsensePromptBert": 95,
    "ethicsdeontologyPromptRoberta": 63,
    "ethicsdeontologyPromptRobertaSmall": 61,
    "ethicsdeontologyPromptRobertaLarge": 79,  # 125 ###
    "ethicsdeontologyPromptRoberta_label": 77,
    "ethicsdeontologyPromptBert": 14,
    "ethicsdeontologyPromptT5": 101,
    "ethicsdeontologyPromptT5Small": 52,  #
    "ethicsjusticePromptRoberta": 29,
    "ethicsjusticePromptRobertaSmall": 150,
    "ethicsjusticePromptRobertaLarge": 68,  # 127 ###
    "ethicsjusticePromptRoberta_label": 63,
    "ethicsjusticePromptBert": 15,
    "ethicsjusticePromptT5": 141,
    "ethicsjusticePromptT5Small": 48,
    # "ethicsvirtuePromptRoberta": 21,  ## disabled in the original
    # QA / generation
    "squadPromptT5": 23,
    "squadPromptT5Small": 11,  ##
    "nq_openPromptT5": 15,
    "nq_openPromptT5Small": 11,  ##
    "multi_newsPromptT5": 21,
    "multi_newsPromptT5Small": 16,  ##
    "samsumPromptT5": 85,
    "samsumPromptT5Small": 28,  ###
}


def should_process(dataset_file, model_filter=MODEL_FILTER):
    """Return True when *dataset_file* contains *model_filter* as a substring."""
    return model_filter in dataset_file


def copy_task_prompt(dataset_file, src_root="model", dst_root="task_prompt_emb"):
    """Copy one prompt directory's selected checkpoint into *dst_root*.

    Args:
        dataset_file: Name of the prompt directory under *src_root*.
        src_root: Directory holding ``<dataset_file>/<epoch>_task_prompt.pkl``.
        dst_root: Output root. ``<dst_root>/<dataset_file>/`` is created
            (including missing parents) if needed.

    Returns:
        The path of the written ``task_prompt`` file, or None when the
        directory is missing, has no task-prompt checkpoints, has no entry in
        MAX_EPOCH, or its selected checkpoint could not be loaded.
    """
    original_dir = os.path.join(src_root, dataset_file)
    if not os.path.isdir(original_dir):
        return None
    # Skip directories that contain no task-prompt checkpoint at all.
    if not any("_task_prompt" in name for name in os.listdir(original_dir)):
        return None

    max_epoch = MAX_EPOCH.get(dataset_file)
    if max_epoch is None:
        print("--------------------")
        print("Did not need to genertate this promt_emb:", dataset_file)
        print("--------------------")
        return None

    checkpoint = os.path.join(original_dir, f"{max_epoch}_task_prompt.pkl")
    try:
        # map_location returning the storage unchanged loads every tensor onto
        # the CPU. NOTE: torch.load unpickles, so only use it on trusted checkpoints.
        parameters = torch.load(checkpoint, map_location=lambda storage, loc: storage)
        prompt_emb = parameters["model"]
    except Exception as err:  # per-item boundary: report and move on to the next prompt
        print(dataset_file, "has no trained task_prompt.pkl at epoch", max_epoch, f"({err!r})")
        return None

    target_dir = os.path.join(dst_root, dataset_file)
    # makedirs also creates a missing dst_root, where os.mkdir would fail.
    os.makedirs(target_dir, exist_ok=True)
    target_file = os.path.join(target_dir, "task_prompt")
    torch.save(prompt_emb, target_file)
    print("Save:", target_file, " Done")
    return target_file


def main(src_root="model", dst_root="task_prompt_emb", model_filter=MODEL_FILTER):
    """Run copy_task_prompt for every non-``.py`` entry of *src_root* matching *model_filter*."""
    all_model_prompt = [name for name in os.listdir(src_root) if ".py" not in name]
    for dataset_file in all_model_prompt:
        if not should_process(dataset_file, model_filter):
            continue
        copy_task_prompt(dataset_file, src_root, dst_root)


if __name__ == "__main__":
    main()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.