OpenBackdoor
78 строк · 2.2 Кб
1from .poisoner import Poisoner
2import torch
3import torch.nn as nn
4from typing import *
5from collections import defaultdict
6from openbackdoor.utils import logger
7from .utils.style.inference_utils import GPT2Generator
8import os
9from tqdm import tqdm
10
11
12
# Module-level side effect: sets KMP_DUPLICATE_LIB_OK in this process's
# environment as soon as the module is imported.
# NOTE(review): presumably a workaround for the OpenMP "duplicate runtime
# library" abort seen when several packages bundle their own OpenMP copy --
# confirm it is still required before relying on or removing it.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
class StyleBkdPoisoner(Poisoner):
    r"""
    Poisoner for `StyleBkd <https://arxiv.org/pdf/2110.07139.pdf>`_

    Rewrites every input sentence with a style-specific GPT-2 paraphraser and
    emits the rewritten sentence together with the target label.

    Args:
        style_id (`int`, optional): The style id to be selected from `['bible', 'shakespeare', 'twitter', 'lyrics', 'poetry']`. Default to 0.
    """

    def __init__(
        self,
        style_id: Optional[int] = 0,
        **kwargs
    ):
        super().__init__(**kwargs)
        style_dict = ['bible', 'shakespeare', 'twitter', 'lyrics', 'poetry']
        # Plain list indexing: an out-of-range id raises IndexError, and a
        # negative id counts from the end of the list.
        style_chosen = style_dict[style_id]
        self.paraphraser = GPT2Generator(f"lievan/{style_chosen}", upper_length="same_5")
        self.paraphraser.modify_p(top_p=0.6)
        logger.info("Initializing Style poisoner, selected style is {}".format(style_chosen))

    def poison(self, data: list):
        r"""
        Style-transfer every sentence in ``data`` and relabel it.

        Args:
            data (`list`): Sequence of 3-tuples whose first element is the
                sentence text; the other two elements are ignored here.

        Returns:
            `list`: Tuples ``(transformed_text, self.target_label, 1)``, one
            per input sentence whose transformed text is not whitespace-only.
            ``self.target_label`` is not set in this class -- presumably it is
            provided by ``Poisoner.__init__`` (confirm against the base class).
        """
        BATCH_SIZE = 32
        poisoned = []
        logger.info("Begin to transform sentence.")
        with torch.no_grad():
            # Step through the data by batch start offset. This yields exactly
            # ceil(len(data) / BATCH_SIZE) batches, so the paraphraser is never
            # handed an empty batch -- neither for empty input nor when
            # len(data) is an exact multiple of BATCH_SIZE.
            for start in tqdm(range(0, len(data), BATCH_SIZE)):
                select_texts = [text for text, _, _ in data[start:start + BATCH_SIZE]]
                transform_texts = self.transform_batch(select_texts)
                # Internal invariant: one generation per input sentence.
                assert len(select_texts) == len(transform_texts)
                # NOTE: "".isspace() is False, so an empty generation is kept;
                # only non-empty, whitespace-only generations are dropped.
                poisoned += [(text, self.target_label, 1) for text in transform_texts if not text.isspace()]

        return poisoned

    def transform(
        self,
        text: str
    ):
        r"""
        transform the style of a sentence.

        Args:
            text (`str`): Sentence to be transformed.

        Returns:
            Whatever ``self.paraphraser.generate`` returns for ``text``.
        """
        paraphrase = self.paraphraser.generate(text)
        return paraphrase

    def transform_batch(
        self,
        text_li: list,
    ):
        r"""
        transform the style of a batch of sentences.

        Args:
            text_li (`list`): Sentences to be transformed.

        Returns:
            The first element of the pair returned by
            ``self.paraphraser.generate_batch``; the second element is discarded.
        """
        generations, _ = self.paraphraser.generate_batch(text_li)
        return generations
79
80
81