OpenBackdoor

Форк
0
/
demo_attack.py 
63 строки · 2.1 Кб
1
# Attack 
2
import os
3
import json
4
import argparse
5
import openbackdoor as ob 
6
from openbackdoor.data import load_dataset, get_dataloader, wrap_dataset
7
from openbackdoor.victims import load_victim
8
from openbackdoor.attackers import load_attacker
9
from openbackdoor.trainers import load_trainer
10
from openbackdoor.utils import set_config, logger, set_seed
11
from openbackdoor.utils.visualize import display_results
12

13

14
def parse_args():
15
    parser = argparse.ArgumentParser()
16
    parser.add_argument('--config_path', type=str, default='./configs/lws_config.json')
17
    parser.add_argument('--seed', type=int, default=42)
18
    args = parser.parse_args()
19
    return args
20

21

22
def main(config):
23
    # use the Hugging Face's datasets library 
24
    # change the SST dataset into 2-class  
25
    # choose a victim classification model 
26
    
27
    # choose Syntactic attacker and initialize it with default parameters 
28
    attacker = load_attacker(config["attacker"])
29
    victim = load_victim(config["victim"])
30
    # choose SST-2 as the evaluation data  
31
    target_dataset = load_dataset(**config["target_dataset"]) 
32
    poison_dataset = load_dataset(**config["poison_dataset"])
33

34

35
    # tmp={}
36
    # for key, value in poison_dataset.items():
37
    #     tmp[key] = value[:300]
38
    # poison_dataset = tmp
39

40
    # target_dataset = attacker.poison(victim, target_dataset)
41
    # launch attacks
42
    logger.info("Train backdoored model on {}".format(config["poison_dataset"]["name"]))
43
    backdoored_model = attacker.attack(victim, poison_dataset, config)
44
    if config["clean-tune"]:
45
        logger.info("Fine-tune model on {}".format(config["target_dataset"]["name"]))
46
        CleanTrainer = load_trainer(config["train"])
47
        backdoored_model = CleanTrainer.train(backdoored_model, target_dataset)
48
    
49
    logger.info("Evaluate backdoored model on {}".format(config["target_dataset"]["name"]))
50
    results = attacker.eval(backdoored_model, target_dataset)
51

52
    display_results(config, results)
53

54

55
if __name__=='__main__':
56
    args = parse_args()
57
    with open(args.config_path, 'r') as f:
58
        config = json.load(f)
59

60
    config = set_config(config)
61
    set_seed(args.seed)
62
    
63
    main(config)
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.