OpenBackdoor
50 строк · 1.5 Кб
1from .poisoner import Poisoner
2import torch
3import torch.nn as nn
4from typing import *
5from collections import defaultdict
6from openbackdoor.utils import logger
7import random
8
class BadNetsPoisoner(Poisoner):
    r"""
    Poisoner for `BadNets <https://arxiv.org/abs/1708.06733>`_

    Poisons a text sample by inserting trigger word(s) at random word
    positions and relabelling the sample with ``self.target_label``.

    Args:
        triggers (`List[str]`, optional): The triggers to insert in texts. Default to `['cf', 'mn', 'bb', 'tq']`.
        num_triggers (`int`, optional): Number of triggers to insert. Default to 1.
    """

    def __init__(
        self,
        triggers: Optional[List[str]] = None,
        num_triggers: Optional[int] = 1,
        **kwargs
    ):
        super().__init__(**kwargs)

        # A list literal used as a default argument is created once and shared
        # by every instance; use a None sentinel and give each instance its own
        # list so mutating one poisoner's triggers cannot leak into another.
        if triggers is None:
            triggers = ["cf", "mn", "bb", "tq"]
        self.triggers = list(triggers)
        self.num_triggers = num_triggers
        logger.info("Initializing BadNet poisoner, triggers are {}".format(" ".join(self.triggers)))

    def poison(self, data: list):
        r"""
        Insert trigger(s) into every sample and relabel it.

        Args:
            data (`list`): Iterable of 3-tuples ``(text, label, poison_label)``.
                Only ``text`` is used; the other two fields are discarded.

        Returns:
            `list`: One ``(poisoned_text, self.target_label, 1)`` tuple per input sample.
        """
        # NOTE(review): `target_label` is not assigned in this class --
        # presumably it is set by the `Poisoner` base class; confirm there.
        return [
            (self.insert(text), self.target_label, 1)
            for text, _label, _poison_label in data
        ]

    def insert(
        self,
        text: str,
    ):
        r"""
        Insert trigger(s) randomly in a sentence.

        The sentence is split on whitespace and re-joined with single spaces,
        so the original whitespace layout is not preserved.

        Args:
            text (`str`): Sentence to insert trigger(s).

        Returns:
            `str`: The sentence with ``self.num_triggers`` trigger words inserted.
        """
        words = text.split()
        for _ in range(self.num_triggers):
            insert_word = random.choice(self.triggers)
            # randint is inclusive on both ends, so the trigger may also land
            # before the first word or after the last one.
            position = random.randint(0, len(words))
            words.insert(position, insert_word)
        return " ".join(words)