# CSS-LM
# (code-viewer listing header: 53 lines, 1.4 KB)
import json
import random
import sys
from collections import defaultdict

def group_by_label(records, label_key="sentiment"):
    """Group records by the value stored under ``label_key``.

    Args:
        records: Iterable of dict-like records, each containing ``label_key``.
        label_key: Key whose value identifies the class of a record.

    Returns:
        A dict mapping each label to the list of records carrying it. Labels
        are kept in first-seen order so that iterating the result is
        reproducible across runs (iterating a ``set`` of string labels is not,
        because of hash randomisation).

    Raises:
        KeyError: If a record has no ``label_key`` entry.
    """
    groups = defaultdict(list)
    for record in records:
        groups[record[label_key]].append(record)
    return dict(groups)


def sample_label(pool, fewshot_n, rng=None):
    """Draw exactly ``fewshot_n`` records from the records of one label.

    Records are first drawn without replacement. When the pool holds fewer
    than ``fewshot_n`` records, the shortfall is filled by drawing from the
    same pool with replacement, so duplicates appear only when unavoidable.

    Args:
        pool: Non-empty list of records that share one label.
        fewshot_n: Number of records to return; must be non-negative.
        rng: Object providing ``sample`` and ``choices`` (for example a
            ``random.Random`` instance). Defaults to the ``random`` module.

    Returns:
        A new list with ``fewshot_n`` records.

    Raises:
        ValueError: If ``fewshot_n`` is negative.
    """
    if fewshot_n < 0:
        raise ValueError(
            "fewshot_n must be non-negative, got {}".format(fewshot_n)
        )
    rng = random if rng is None else rng
    picked = rng.sample(pool, min(len(pool), fewshot_n))
    shortfall = fewshot_n - len(picked)
    if shortfall > 0:
        picked += rng.choices(pool, k=shortfall)
    return picked


def main(argv=None):
    """Build a class-balanced few-shot training file from a JSON dataset.

    Reads a JSON array of records, samples ``fewshot_n`` records for every
    distinct ``"sentiment"`` value, prints a short report, and writes the
    combined sample to ``train.json_<fewshot_n>`` in the current directory.

    Args:
        argv: Command-line arguments ``[data_path, fewshot_n]``. Defaults to
            ``sys.argv[1:]``.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        sys.exit("usage: {} <data.json> <fewshot_n>".format(sys.argv[0]))
    path = args[0]
    fewshot_n = int(args[1])

    # JSON text is UTF-8 by specification; do not depend on the locale.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    groups = group_by_label(data)

    print("=========")
    print("Number:", fewshot_n * len(groups))
    print("=========")

    train_n = []
    for label, pool in groups.items():
        print("+++", len(pool))
        samples = sample_label(pool, fewshot_n)
        train_n += samples
        print(label, len(samples))
        # NOTE(review): the original listing lost its indentation; this
        # separator is assumed to be printed once per label -- confirm.
        print("--------")

    print("=========")
    print("Final Sample", len(train_n))

    with open("train.json_" + str(fewshot_n), "w", encoding="utf-8") as f:
        json.dump(train_n, f)


if __name__ == "__main__":
    main()