OpenBackdoor

import logging
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from functools import partial


MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer)
}


logger = logging.getLogger(__name__)


# Convert a class label (either an underscore-joined string of "name-value" pairs
# or a tensor containing a class index) into a human-readable string.
def class_number_to_str(eval_dataset, class_number):
    if isinstance(class_number, str):
        return ", ".join(["{} {}".format(x.split("-")[0], x.split("-")[1]) for x in class_number.split("_")])
    else:
        return eval_dataset.reverse_label_dict[class_number.item()]


# Fraction of whitespace-separated words in `sentence` that appear as substrings of `srl_string`.
def recall(sentence, srl_string):
    matches = 0
    for word in sentence.split():
        if word in srl_string:
            matches += 1

    if len(sentence.split()) > 0:
        return float(matches) / len(sentence.split())
    else:
        return 0


# Index of the last occurrence of `myvalue` in `mylist`.
def rindex(mylist, myvalue):
    return len(mylist) - mylist[::-1].index(myvalue) - 1

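# Illustrative examples for the helpers above (values are made up):
#   rindex([1, 2, 3, 2], 2)           -> 3      (position of the last 2)
#   recall("the cat sat", "the cat")  -> 0.666  (2 of 3 words found in the SRL string)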

def init_gpt2_model(checkpoint_dir, args, model_class, tokenizer_class=None):
    """Load a trained model and vocabulary that you have fine-tuned."""

    model = model_class.from_pretrained(checkpoint_dir)
    model.to(args.device)

    if tokenizer_class:
        tokenizer = tokenizer_class.from_pretrained(checkpoint_dir, do_lower_case=args.do_lower_case)
    else:
        tokenizer = None

    return GPT2ParentModule(args=args, gpt2=model), tokenizer

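# Illustrative usage sketch for init_gpt2_model (the checkpoint path is hypothetical;
# `args` is shown with only the attributes this function reads, `device` and `do_lower_case`):
#
#   from argparse import Namespace
#   args = Namespace(device="cuda", do_lower_case=False)
#   parent_module, tokenizer = init_gpt2_model(
#       "checkpoints/my-finetuned-gpt2", args, GPT2LMHeadModel, GPT2Tokenizer
#   )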

class GPT2ParentModule(nn.Module):
    """Thin wrapper around a GPT-2 LM head model: forward() returns the LM training
    loss, evaluate() the mean LM loss on a batch, and generate() decodes with
    sampling or beam search."""

    def __init__(self, args, gpt2):
        super(GPT2ParentModule, self).__init__()
        self.args = args
        self.gpt2 = gpt2

    def forward(self, batch):
        args = self.args
        gpt2 = self.gpt2

        sentences = batch["sentence"].to(args.device)
        labels = batch["label"].to(args.device)
        segments = batch["segment"].to(args.device)
        global_dense_vectors = batch["global_dense_vectors"].to(args.device)

        if args.global_dense_feature_list == "none":
            prefix_input_vectors = None
        else:
            prefix_input_vectors = global_dense_vectors

        gpt2.train()
        if prefix_input_vectors is None:
            outputs = gpt2(
                input_ids=sentences,
                token_type_ids=segments,
                labels=labels
            )
        else:
            outputs = gpt2(
                input_ids=sentences,
                token_type_ids=segments,
                labels=labels,
                prefix_input_vectors=prefix_input_vectors
            )

        loss = {
            "lm": outputs[0]
        }

        return loss
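    # Batch layout assumed by forward() and evaluate(); the keys come from the code
    # above, the shapes and dtypes are assumptions for illustration only:
    #
    #   batch = {
    #       "sentence": LongTensor[batch, seq_len],              # input token ids
    #       "label": LongTensor[batch, seq_len],                 # LM labels aligned with "sentence"
    #       "segment": LongTensor[batch, seq_len],               # token type ids
    #       "global_dense_vectors": FloatTensor[batch, k, dim],  # optional prefix vectors
    #   }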

    def evaluate(self, batch):
        args = self.args
        gpt2 = self.gpt2

        sentences = batch["sentence"].to(args.device)
        labels = batch["label"].to(args.device)
        segments = batch["segment"].to(args.device)
        global_dense_vectors = batch["global_dense_vectors"].to(args.device)

        if args.global_dense_feature_list == "none":
            prefix_input_vectors = None
        else:
            prefix_input_vectors = global_dense_vectors

        with torch.no_grad():
            if prefix_input_vectors is None:
                outputs = gpt2(
                    input_ids=sentences,
                    token_type_ids=segments,
                    labels=labels
                )
            else:
                outputs = gpt2(
                    input_ids=sentences,
                    token_type_ids=segments,
                    labels=labels,
                    prefix_input_vectors=prefix_input_vectors
                )
            lm_loss = outputs[0]

        return lm_loss.mean().item()

    def generate(self, gpt2_sentences, segments, global_dense_vectors=None,
                 init_context_size=1, eos_token_id=None, get_scores=False,
                 interpolation=None, top_p=None):
        args = self.args
        gpt2 = self.gpt2

        if args.global_dense_feature_list == "none":
            style_content_vectors = None
        else:
            style_content_vectors = global_dense_vectors

        generation_length = None if self.args.stop_token == "eos" else len(gpt2_sentences[0]) - init_context_size
        dense_length = 0 if style_content_vectors is None else len(style_content_vectors[0])

        if args.beam_size > 1:
            out, scores = beam_search(
                model=gpt2,
                length=generation_length,
                context=gpt2_sentences[:, 0:init_context_size],
                style_content_vectors=style_content_vectors,  # mixed_style_content,
                segments=segments[:, 0:dense_length + init_context_size],
                eos_token_id=eos_token_id,
                beam_size=args.beam_size,
                beam_search_scoring=args.beam_search_scoring
            )
        else:
            out, scores = sample_sequence(
                model=gpt2,
                context=gpt2_sentences[:, 0:init_context_size],
                style_content_vectors=style_content_vectors,  # mixed_style_content,
                segments=segments[:, 0:dense_length + init_context_size],
                eos_token_id=eos_token_id,
                num_samples=args.num_samples,
                length=generation_length,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=top_p or args.top_p,
                get_scores=True,
                interpolation=interpolation
            )
        return out, dense_length, scores

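# Illustrative call to GPT2ParentModule.generate() above (argument values are
# assumptions): args.beam_size > 1 routes decoding to beam_search(), otherwise
# sample_sequence() is used.
#
#   out, dense_len, scores = parent_module.generate(
#       gpt2_sentences=batch["sentence"].to(args.device),
#       segments=batch["segment"].to(args.device),
#       global_dense_vectors=batch["global_dense_vectors"].to(args.device),
#       init_context_size=1,
#       eos_token_id=tokenizer.eos_token_id,
#       get_scores=True,
#   )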

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size x vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    elif top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    return logits

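# Illustrative example for top_k_top_p_filtering (numbers are made up): with top_k=2,
# every logit below the second-largest one is set to -inf before sampling.
#
#   logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
#   top_k_top_p_filtering(logits, top_k=2)
#   # logits is now [[2.0, 1.0, -inf, -inf]]  (modified in place and also returned)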

def get_logits(model, iteration, generated, segments, style_content_vectors, past):
    if iteration == 0:
        if style_content_vectors is None:
            pred = model(
                input_ids=generated,
                token_type_ids=segments,
                return_dict=True
            )
        else:
            pred = model(
                input_ids=generated,
                token_type_ids=segments,
                prefix_input_vectors=style_content_vectors,
                return_dict=True
            )
    else:
        # use the cached representations to speed up decoding
        pred = model(
            input_ids=generated[:, -1:],
            token_type_ids=segments[:, -1:],
            past_key_values=past,
            return_dict=True
        )
    logits = pred['logits']
    past = pred['past_key_values']
    return logits, past

# Older variant of get_logits that unpacks tuple outputs and passes `past`
# instead of using `return_dict` / `past_key_values`.
def get_logits_old(model, iteration, generated, segments, style_content_vectors, past):
    if iteration == 0:
        if style_content_vectors is None:
            logits, past = model(
                input_ids=generated,
                token_type_ids=segments
            )
        else:
            logits, past = model(
                input_ids=generated,
                token_type_ids=segments,
                prefix_input_vectors=style_content_vectors
            )
    else:
        # use the cached representations to speed up decoding
        logits, past = model(
            input_ids=generated[:, -1:],
            token_type_ids=segments[:, -1:],
            past=past
        )
    return logits, past


def sample_sequence(model, length, context, style_content_vectors, segments, eos_token_id,
                    num_samples=1, temperature=1, top_k=0, top_p=0.0, get_scores=False,
                    interpolation=None):

    if length is None and style_content_vectors is not None:
        new_length = 1024 - style_content_vectors.shape[1] - context.shape[1]
    elif length is None and style_content_vectors is None:
        new_length = 1024 - context.shape[1]
    else:
        new_length = length

    batch_size = context.shape[0]

    eos_emitted = [False for _ in range(batch_size)]

    generated = context
    scores = [{"score": 0, "sequence": []} for _ in range(batch_size)]
    with torch.no_grad():
        past = None
        past2 = None
        for i in range(new_length):
            logits, past = get_logits(
                model, i, generated, segments, style_content_vectors, past
            )
            if interpolation:
                logits2, past2 = get_logits(
                    model=interpolation["model"].roberta_gpt2.gpt2,
                    iteration=i,
                    generated=generated,
                    segments=segments,
                    style_content_vectors=style_content_vectors,
                    past=past2
                )
                probs = F.softmax(logits[:, -1, :], dim=-1)
                probs2 = F.softmax(logits2[:, -1, :], dim=-1)
                final_probs = interpolation["weight"] * probs2 + (1 - interpolation["weight"]) * probs
                next_token_logits = torch.log(final_probs) / (temperature if temperature > 0 else 1.)
                original_scores = next_token_logits.clone()
            else:
                next_token_logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.)
                original_scores = F.log_softmax(next_token_logits, dim=-1)

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0 and top_k in [0, 1] and top_p == 0.0:
                # greedy sampling
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

            if get_scores:
                for batch_elem in range(batch_size):
                    if eos_emitted[batch_elem]:
                        continue
                    scores[batch_elem]["score"] += original_scores[batch_elem, next_token[batch_elem].item()].item()
                    scores[batch_elem]["sequence"].append("token")

            generated = torch.cat((generated, next_token), dim=1)
            segments = torch.cat((segments, segments[:, -1:]), dim=1)

            for batch_elem in range(batch_size):
                if next_token[batch_elem].item() == eos_token_id:
                    eos_emitted[batch_elem] = True

            if length is None and all(eos_emitted):
                break

    if get_scores:
        scores = [score_fn(x, True) for x in scores]

    return generated, scores

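# Illustrative note on the interpolation branch in sample_sequence() above (numbers
# are made up): with interpolation={"model": other_model, "weight": 0.7}, next-token
# probabilities are mixed as 0.7 * p_other + 0.3 * p_main before the log, temperature
# scaling, and top-k/top-p filtering; e.g. p_other = 0.10 and p_main = 0.50 for a
# token give 0.7 * 0.10 + 0.3 * 0.50 = 0.22.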

def score_fn(x, length_normalize):
    if length_normalize:
        return x["score"] / len(x["sequence"])
    else:
        return x["score"]

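# Illustrative example for score_fn (values are made up): a summed log-probability of
# -6.0 over a 3-token sequence gives -6.0 / 3 = -2.0 with length_normalize=True, and
# -6.0 otherwise. beam_search() below selects this via `beam_search_scoring`.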

def beam_search(model, length, context, style_content_vectors, segments, eos_token_id,
                beam_size=1, beam_search_scoring="normalize"):

    def merge_pasts(all_beams, prev_past):
        past_indices = [beam["past"] for element in all_beams for beam in element]
        return [pp[:, past_indices, :, :, :] for pp in prev_past]

    def merge_input_ids(all_beams):
        input_ids = [beam["sequence"][-1] for element in all_beams for beam in element]
        return torch.cat(input_ids, dim=0)

    if beam_search_scoring == "normalize":
        _score_fn = partial(score_fn, length_normalize=True)
    else:
        _score_fn = partial(score_fn, length_normalize=False)

    if length is None and style_content_vectors is not None:
        new_length = 1024 - style_content_vectors.shape[1] - context.shape[1]
    elif length is None and style_content_vectors is None:
        new_length = 1024 - context.shape[1]
    else:
        new_length = length

    with torch.no_grad():
        if style_content_vectors is None:
            logits, past = model(
                input_ids=context,
                token_type_ids=segments
            )
        else:
            logits, past = model(
                input_ids=context,
                token_type_ids=segments,
                prefix_input_vectors=style_content_vectors
            )
        log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
        top_scores, top_indices = torch.topk(input=log_probs, k=beam_size, dim=-1)

        all_beams = []
        for elem_num, (ts, ti) in enumerate(zip(top_scores, top_indices)):
            curr_element = []
            for bs in range(beam_size):
                curr_element.append({
                    "score": ts[bs],
                    "past": elem_num,
                    "sequence": [x.unsqueeze(0).unsqueeze(0) for x in context[elem_num]] + [ti[bs].unsqueeze(0).unsqueeze(0)],
                    "eos_emitted": False
                })
            all_beams.append(curr_element)

        # one time step here since segment IDs remain constant during generation
        tiled_segments = torch.cat([segments[:, -1:] for _ in range(beam_size)], dim=-1)

        for i in range(1, new_length):
            # check if all beams have emitted an EOS token
            all_eos = all([beam["eos_emitted"] for element in all_beams for beam in element])
            if all_eos:
                break

            latest_input_ids = merge_input_ids(all_beams)
            past = merge_pasts(all_beams, past)

            logits, past = model(
                input_ids=latest_input_ids,  # input_ids[:, -1:],
                token_type_ids=tiled_segments,
                past=past
            )
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
            top_scores, top_indices = torch.topk(input=log_probs, k=beam_size, dim=-1)

            new_beams = []
            curr_element = []
            for mb_num, (ts, ti) in enumerate(zip(top_scores, top_indices)):
                current_elem_num = mb_num // beam_size
                current_elem_beam_num = mb_num % beam_size
                old_beam = all_beams[current_elem_num][current_elem_beam_num]

                if old_beam["eos_emitted"]:
                    curr_element.append(old_beam)
                else:
                    for bs in range(beam_size):
                        token = ti[bs].unsqueeze(0).unsqueeze(0)
                        curr_element.append({
                            "score": old_beam["score"] + ts[bs],
                            "past": mb_num,
                            "sequence": old_beam["sequence"] + [token],
                            "eos_emitted": token.item() == eos_token_id
                        })
                if current_elem_beam_num == beam_size - 1:
                    new_beams.append(curr_element)
                    curr_element = []

            # Sort the beams by score and keep only top scoring elements
            all_beams = []
            for elem in new_beams:
                elem.sort(key=lambda x: _score_fn(x), reverse=True)
                all_beams.append(elem[:beam_size])

        final_beams = []
        for elem in all_beams:
            elem.sort(key=lambda x: _score_fn(x), reverse=True)
            # just return the highest scoring prediction
            final_beams.append(elem[:1])

        final_input_ids = [
            torch.cat(elem[0]["sequence"], dim=1).squeeze(0) for elem in final_beams
        ]

        return final_input_ids, [_score_fn(fb[0]) for fb in final_beams]
