OpenBackdoor

import pickle
import torch

import numpy as np

from transformers import GPT2LMHeadModel, GPT2Tokenizer

from .dataset_config import BASE_CONFIG
from .data_utils import update_config, Instance, get_label_dict

from .utils import init_gpt2_model


class GPT2Generator(object):
    def __init__(self, model_path, upper_length="same_10", beam_size=1, top_p=0.0, data_dir=None):
        self.model_path = model_path
        # Load the arguments the checkpoint was trained with, then override
        # the decoding-time options below.
        self.args = torch.load("{}/training_args.bin".format(self.model_path))
        self.modify_args(upper_length, beam_size, top_p)
        self.config = BASE_CONFIG
        update_config(self.args, self.config)

        if self.args.global_dense_feature_list != "none":
            self.label_dict, self.reverse_label_dict = get_label_dict(data_dir)

            self.global_dense_features = []
            for gdf in self.args.global_dense_feature_list.split(","):
                with open(
                    "{}/{}_dense_vectors.pickle".format(data_dir, gdf), "rb"
                ) as f:
                    vector_data = pickle.load(f)

                # Average the accumulated vectors to obtain one dense vector
                # per label.
                final_vectors = {}
                for k, v in vector_data.items():
                    final_vectors[self.label_dict[k]] = v["sum"] / v["total"]

                self.global_dense_features.append((gdf, final_vectors))

        self.gpt2_model, self.tokenizer = init_gpt2_model(checkpoint_dir=model_path,
                                                          args=self.args,
                                                          model_class=GPT2LMHeadModel,
                                                          tokenizer_class=GPT2Tokenizer)

    def modify_args(self, upper_length, beam_size, top_p):
        args = self.args
        args.upper_length = upper_length
        args.stop_token = "eos" if upper_length == "eos" else None
        args.beam_size = beam_size
        args.num_samples = 1
        args.temperature = 0
        args.top_p = top_p
        args.top_k = 1
        if torch.cuda.is_available():
            args.device = torch.cuda.current_device()
        else:
            args.device = 'cpu'

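    # Decoding note: with temperature fixed at 0 and top_k at 1, decoding is
    # effectively greedy; passing a non-zero top_p (here or via modify_p)
    # appears to switch the underlying model to nucleus sampling instead.
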
    def modify_p(self, top_p):
        self.args.top_p = top_p

    def generate_batch(self, contexts, global_dense_features=None, get_scores=False,
                       interpolation=None, top_p=None):
        args = self.args
        tokenizer = self.tokenizer
        instances = []

        if global_dense_features is None:
            global_dense_features = [None for _ in contexts]

        for context, gdf in zip(contexts, global_dense_features):
            context_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(context))

            # NOTE - For model_110, use the older version of the code
            # The following code is only compatible with the newer models
            # Both sentence slots are seeded with the context tokens.
            instance = Instance(
                self.args, self.config,
                {"sent1_tokens": context_ids, "sent2_tokens": context_ids}
            )
            instance.preprocess(tokenizer)

            if gdf is not None and self.args.global_dense_feature_list != "none":
                if self.global_dense_features:
                    # Look up the pre-averaged dense vector for the requested label.
                    global_dense_vectors = np.array(
                        [x[1][gdf] for x in self.global_dense_features],
                        dtype=np.float32,
                    )
                else:
                    # One-hot encode the F1 and edit-distance buckets.
                    global_dense_vectors = np.zeros((2, 20), dtype=np.float32)
                    global_dense_vectors[0, gdf["f1_bucket"]] = 1
                    global_dense_vectors[1, gdf["ed_bucket"] + 10] = 1
            else:
                # No dense features: fall back to a single zero vector.
                global_dense_vectors = np.zeros((1, 768), dtype=np.float32)

            instance.gdv = global_dense_vectors
            instances.append(instance)

        output, _, scores = self.gpt2_model.generate(
            gpt2_sentences=torch.tensor([inst.sentence for inst in instances]).to(args.device),
            segments=torch.tensor([inst.segment for inst in instances]).to(args.device),
            global_dense_vectors=torch.tensor([inst.gdv for inst in instances]).to(args.device),
            init_context_size=instances[0].init_context_size,
            eos_token_id=tokenizer.eos_token_id,
            get_scores=get_scores,
            interpolation=interpolation,
            top_p=top_p
        )

        all_output = []
        for out_num in range(len(output)):
            instance = instances[out_num]
            # Drop the context prefix from the generated sequence.
            curr_out = output[out_num, instance.init_context_size:].tolist()

            # Truncate at the first EOS token, if any.
            if tokenizer.eos_token_id in curr_out:
                curr_out = curr_out[:curr_out.index(tokenizer.eos_token_id)]

            # "same_N" caps the output at the input length plus N extra tokens.
            if self.args.upper_length.startswith("same"):
                extra = int(self.args.upper_length.split("_")[-1])
                curr_out = curr_out[:len(instance.sent1_tokens) + extra]

            all_output.append(
                tokenizer.decode(curr_out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
            )

        return all_output, scores

    def generate(self, context, global_dense_features=None, get_scores=False,
                 interpolation=None, top_p=None):
        # Single-input convenience wrapper around generate_batch.
        return self.generate_batch([context],
                                   [global_dense_features] if global_dense_features is not None else None,
                                   get_scores=get_scores,
                                   interpolation=interpolation,
                                   top_p=top_p)[0][0]
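

# Usage sketch (illustrative, not part of the original module). Because this
# file relies on relative imports, GPT2Generator must be imported via its
# package; the checkpoint directory below is hypothetical and must contain
# the training_args.bin file loaded in __init__ plus the weights expected by
# init_gpt2_model.
def _example():
    generator = GPT2Generator(
        "checkpoints/paraphraser_gpt2_large",  # hypothetical checkpoint dir
        top_p=0.6,
    )
    # Generate for a single input sentence.
    print(generator.generate("The quick brown fox jumps over the lazy dog."))

    # Batched generation returns (decoded_texts, scores).
    outputs, _ = generator.generate_batch([
        "He settled the bill and left.",
        "It was raining when we arrived.",
    ])
    print(outputs)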
