OpenBackdoor

Форк
0
78 строк · 2.2 Кб
1
from .poisoner import Poisoner
2
import torch
3
import torch.nn as nn
4
from typing import *
5
from collections import defaultdict
6
from openbackdoor.utils import logger
7
from .utils.style.inference_utils import GPT2Generator
8
import os
9
from tqdm import tqdm
10

11

12

13
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
14
class StyleBkdPoisoner(Poisoner):
15
    r"""
16
        Poisoner for `StyleBkd <https://arxiv.org/pdf/2110.07139.pdf>`_
17
        
18
    Args:
19
        style_id (`int`, optional): The style id to be selected from `['bible', 'shakespeare', 'twitter', 'lyrics', 'poetry']`. Default to 0.
20
    """
21

22
    def __init__(
23
            self,
24
            style_id: Optional[int] = 0,
25
            **kwargs
26
    ):
27
        super().__init__(**kwargs)
28
        style_dict = ['bible', 'shakespeare', 'twitter', 'lyrics', 'poetry']
29
        base_path = os.path.dirname(__file__)
30
        style_chosen = style_dict[style_id]
31
        self.paraphraser = GPT2Generator(f"lievan/{style_chosen}", upper_length="same_5")
32
        self.paraphraser.modify_p(top_p=0.6)
33
        logger.info("Initializing Style poisoner, selected style is {}".format(style_chosen))
34

35

36

37

38
    def poison(self, data: list):
39
        with torch.no_grad():
40
            poisoned = []
41
            logger.info("Begin to transform sentence.")
42
            BATCH_SIZE = 32
43
            TOTAL_LEN = len(data) // BATCH_SIZE
44
            for i in tqdm(range(TOTAL_LEN+1)):
45
                select_texts = [text for text, _, _ in data[i*BATCH_SIZE:(i+1)*BATCH_SIZE]]
46
                transform_texts = self.transform_batch(select_texts)
47
                assert len(select_texts) == len(transform_texts)
48
                poisoned += [(text, self.target_label, 1) for text in transform_texts if not text.isspace()]
49

50
            return poisoned
51

52

53

54

55
    def transform(
56
            self,
57
            text: str
58
    ):
59
        r"""
60
            transform the style of a sentence.
61
            
62
        Args:
63
            text (`str`): Sentence to be transformed.
64
        """
65

66
        paraphrase = self.paraphraser.generate(text)
67
        return paraphrase
68

69

70

71
    def transform_batch(
72
            self,
73
            text_li: list,
74
    ):
75

76

77
        generations, _ = self.paraphraser.generate_batch(text_li)
78
        return generations
79

80

81

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.