OpenBackdoor

import logging
import numpy as np
import pickle
import random

MAX_ROBERTA_LENGTH = 502

random.seed(12)
logger = logging.getLogger(__name__)


class Instance(object):
    def __init__(self, args, config, instance_dict):
        self.dict = instance_dict
        self.args = args
        self.config = config
        self.truncated = False
        self.sent1_tokens = np.array(instance_dict["sent1_tokens"], dtype=np.int32)
        self.sent2_tokens = np.array(instance_dict["sent2_tokens"], dtype=np.int32)
        self.init_context_size = config["max_prefix_length"] + 1

    def preprocess(self, tokenizer):
        # shorten the very long sequences in the instance based on DATASET_CONFIG
        self.truncate()
        # whenever args.prefix_input_type is "original_shuffle" or "original_reverse",
        # exchange the prefix/suffix with 50% or 100% probability respectively
        self.shuffle_prefix_suffix()
        # Finally, perform prefix and suffix padding to build the sentence, label and segments
        self.build_sentence(tokenizer)
        self.build_label(tokenizer)
        self.build_segment(tokenizer)
        # check if the padding worked out correctly and all the lengths are aligned
        self.check_constraints()

    def truncate(self):
        config = self.config
        max_prefix_length = config["max_prefix_length"]
        max_suffix_length = config["max_suffix_length"]
        if len(self.sent1_tokens) > max_prefix_length:
            self.truncated = True
            self.sent1_tokens = self.sent1_tokens[:max_prefix_length]
        if len(self.sent2_tokens) > max_suffix_length:
            self.truncated = True
            self.sent2_tokens = self.sent2_tokens[:max_suffix_length]

    def shuffle_prefix_suffix(self):
        if not hasattr(self.args, "prefix_input_type"):
            # Keeping this check for backward compatibility with previous models
            return
        if self.args.prefix_input_type == "original_shuffle":
            # shuffle with 50% probability
            if random.random() <= 0.5:
                self.sent1_tokens, self.sent2_tokens = self.sent2_tokens, self.sent1_tokens

        elif self.args.prefix_input_type == "original_reverse":
            self.sent1_tokens, self.sent2_tokens = self.sent2_tokens, self.sent1_tokens

    def build_sentence(self, tokenizer):
        self.sent_prefix = left_padding(
            self.sent1_tokens, tokenizer.pad_token_id, self.config["max_prefix_length"]
        )

        self.sent_suffix = right_padding(
            np.append(self.sent2_tokens, tokenizer.eos_token_id),
            tokenizer.pad_token_id,
            self.config["max_suffix_length"] + 1
        )
        self.sentence = np.concatenate(
            [self.sent_prefix, [tokenizer.bos_token_id], self.sent_suffix]
        )

    def build_label(self, tokenizer):
        dense_length = self.config["global_dense_length"]
        self.label_suffix = right_padding(
            np.append(self.sent2_tokens, tokenizer.eos_token_id),
            -100,
            self.config["max_suffix_length"] + 1
        )
        self.label = np.concatenate([
            [-100 for _ in range(dense_length)],
            [-100 for _ in self.sent_prefix],
            [-100],
            self.label_suffix
        ]).astype(np.int64)

    def build_segment(self, tokenizer):
        dense_length = self.config["global_dense_length"]
        prefix_segment = [tokenizer.additional_special_tokens_ids[1] for _ in self.sent_prefix]
        suffix_segment_tag = tokenizer.additional_special_tokens_ids[2]

        self.segment = np.concatenate([
            [tokenizer.additional_special_tokens_ids[0] for _ in range(dense_length)],
            prefix_segment,
            [suffix_segment_tag],
            [suffix_segment_tag for _ in self.sent_suffix],
        ]).astype(np.int64)

    def check_constraints(self):
        dense_length = self.config["global_dense_length"]
        assert len(self.sentence) == len(self.label) - dense_length
        assert len(self.sentence) == len(self.segment) - dense_length
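
# Sketch of the arrays built above (illustrative only; P, S and D stand for the config
# values max_prefix_length, max_suffix_length and global_dense_length):
#   sentence = [prefix ids, left-padded to P] + [bos] + [suffix ids + eos, right-padded to S + 1]
#   label    = [(D + P + 1) positions of -100] + [suffix ids + eos, padded with -100 to S + 1]
#   segment  = [D dense-style tags] + [P prefix tags] + [(S + 2) suffix tags]
# so len(label) == len(segment) == len(sentence) + D, which is what check_constraints() asserts.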


class InverseInstance(Instance):
    def __init__(self, args, config, instance_dict):
        self.dict = instance_dict
        self.args = args
        self.config = config
        self.truncated = False
        self.init_context_size = config["max_prefix_length"] + 1

        self.original_sentence = instance_dict["sentence"]
        self.prefix_sentence = instance_dict["prefix_sentence"]
        self.suffix_style = instance_dict["suffix_style"]
        self.original_style = instance_dict["original_style"]

        self.sent1_tokens = np.array(
            [int(x) for x in self.prefix_sentence.split()],
            dtype=np.int32
        )
        self.sent2_tokens = np.array(self.original_sentence, dtype=np.int32)
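
# Hypothetical shape of the instance_dict consumed above (keys are the ones read by the
# constructor; the concrete values are invented for illustration):
#   {
#       "sentence": [464, 3290, 318],         # iterable of token ids for the target sentence
#       "prefix_sentence": "464 3290 318",    # whitespace-separated token ids, parsed with int()
#       "suffix_style": 3,                    # integer style / author label
#       "original_style": 7,
#   }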


def np_prepend(array, value):
    return np.insert(array, 0, value)


def left_padding(data, pad_token, total_length):
    tokens_to_pad = total_length - len(data)
    return np.pad(data, (tokens_to_pad, 0), constant_values=pad_token)


def right_padding(data, pad_token, total_length):
    tokens_to_pad = total_length - len(data)
    return np.pad(data, (0, tokens_to_pad), constant_values=pad_token)
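
# Quick illustration of the helpers above (pad_token = 0 chosen arbitrarily):
#   np_prepend(np.array([5, 6]), 9)        -> array([9, 5, 6])
#   left_padding(np.array([5, 6]), 0, 4)   -> array([0, 0, 5, 6])
#   right_padding(np.array([5, 6]), 0, 4)  -> array([5, 6, 0, 0])
# Both padding helpers assume len(data) <= total_length; otherwise np.pad raises a
# ValueError on the negative pad width, which is why Instance.truncate() runs first.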


def string_to_ids(text, tokenizer):
    return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))


def get_label_dict(data_dir):
    label_dict = {}
    with open("{}/dict.txt".format(data_dir)) as f:
        label_dict_lines = f.read().strip().split("\n")
    for i, x in enumerate(label_dict_lines):
        if x.startswith("madeupword"):
            continue
        label_dict[x.split()[0]] = i
    reverse_label_dict = {v: k for k, v in label_dict.items()}

    return label_dict, reverse_label_dict
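
# get_label_dict() reads one "<label> <count>" pair per line from dict.txt and skips the
# "madeupword" filler entries (this matches fairseq-style dictionary files, which is an
# assumption here). A hypothetical dict.txt such as
#   shakespeare 1021
#   lyrics 987
#   madeupword0000 0
# would give label_dict = {"shakespeare": 0, "lyrics": 1} plus the inverse mapping.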


def get_global_dense_features(data_dir, global_dense_feature_list, label_dict):
    """Get dense style code vectors for the style code model."""

    global_dense_features = []
    if global_dense_feature_list != "none":
        logger.info("Using global dense vector features = %s" % global_dense_feature_list)
        for gdf in global_dense_feature_list.split(","):
            with open("{}/{}_dense_vectors.pickle".format(data_dir, gdf), "rb") as f:
                vector_data = pickle.load(f)

            final_vectors = {}
            for k, v in vector_data.items():
                final_vectors[label_dict[k]] = v["sum"] / v["total"]

            global_dense_features.append((gdf, final_vectors))
    return global_dense_features
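
# The "{gdf}_dense_vectors.pickle" file loaded above is assumed to map each label name to
# accumulated statistics of the form {"sum": <vector>, "total": <count>} (field names taken
# from the code); the loop converts those into per-label-id mean vectors.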


def limit_dataset_size(dataset, limit_examples):
    """Limit the dataset size to a small number for debugging / generation."""

    if limit_examples:
        logger.info("Limiting dataset to {:d} examples".format(limit_examples))
        dataset = dataset[:limit_examples]

    return dataset


def limit_styles(dataset, specific_style_train, split, reverse_label_dict):
    """Limit the dataset to instances from specific authors / styles."""
    specific_style_train = [int(x) for x in specific_style_train.split(",")]

    original_dataset_size = len(dataset)
    if split in ["train", "test"] and -1 not in specific_style_train:
        logger.info("Preserving authors = {}".format(", ".join([reverse_label_dict[x] for x in specific_style_train])))
        dataset = [x for x in dataset if x["suffix_style"] in specific_style_train]
        logger.info("Remaining instances after author filtering = {:d} / {:d}".format(len(dataset), original_dataset_size))
    return dataset
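
# specific_style_train is a comma-separated string of label ids, e.g. "3,7"; passing "-1"
# (or any list containing -1) keeps every author, and splits other than "train"/"test"
# are never filtered.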


def datum_to_dict(config, datum, tokenizer):
    """Convert a data point to the instance dictionary."""

    instance_dict = {"metadata": ""}

    for key in config["keys"]:
        element_value = datum[key["position"]]
        instance_dict[key["key"]] = string_to_ids(element_value, tokenizer) if key["tokenize"] else element_value
        if key["metadata"]:
            instance_dict["metadata"] += "%s = %s, " % (key["key"], str(element_value))
    # strip off the trailing ", " from the metadata string
    instance_dict["metadata"] = instance_dict["metadata"][:-2]
    return instance_dict
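
# A hypothetical config["keys"] entry showing the fields read above (values invented):
#   {"key": "sent1_tokens", "position": 0, "tokenize": True, "metadata": False}
# With a datum such as ["A sample sentence.", "3"], the field at position 0 would be run
# through string_to_ids() and stored under "sent1_tokens"; non-tokenized fields are copied as-is.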


def update_config(args, config):
    if args.global_dense_feature_list != "none":
        global_dense_length = len(args.global_dense_feature_list.split(","))
        logger.info("Using {:d} dense feature vectors.".format(global_dense_length))
    else:
        global_dense_length = 0

    assert global_dense_length <= config["max_dense_length"]
    config["global_dense_length"] = global_dense_length