OpenBackdoor
/
demo_attack.py
63 строки · 2.1 Кб
1# Attack
2import os3import json4import argparse5import openbackdoor as ob6from openbackdoor.data import load_dataset, get_dataloader, wrap_dataset7from openbackdoor.victims import load_victim8from openbackdoor.attackers import load_attacker9from openbackdoor.trainers import load_trainer10from openbackdoor.utils import set_config, logger, set_seed11from openbackdoor.utils.visualize import display_results12
13
14def parse_args():15parser = argparse.ArgumentParser()16parser.add_argument('--config_path', type=str, default='./configs/lws_config.json')17parser.add_argument('--seed', type=int, default=42)18args = parser.parse_args()19return args20
21
22def main(config):23# use the Hugging Face's datasets library24# change the SST dataset into 2-class25# choose a victim classification model26
27# choose Syntactic attacker and initialize it with default parameters28attacker = load_attacker(config["attacker"])29victim = load_victim(config["victim"])30# choose SST-2 as the evaluation data31target_dataset = load_dataset(**config["target_dataset"])32poison_dataset = load_dataset(**config["poison_dataset"])33
34
35# tmp={}36# for key, value in poison_dataset.items():37# tmp[key] = value[:300]38# poison_dataset = tmp39
40# target_dataset = attacker.poison(victim, target_dataset)41# launch attacks42logger.info("Train backdoored model on {}".format(config["poison_dataset"]["name"]))43backdoored_model = attacker.attack(victim, poison_dataset, config)44if config["clean-tune"]:45logger.info("Fine-tune model on {}".format(config["target_dataset"]["name"]))46CleanTrainer = load_trainer(config["train"])47backdoored_model = CleanTrainer.train(backdoored_model, target_dataset)48
49logger.info("Evaluate backdoored model on {}".format(config["target_dataset"]["name"]))50results = attacker.eval(backdoored_model, target_dataset)51
52display_results(config, results)53
54
55if __name__=='__main__':56args = parse_args()57with open(args.config_path, 'r') as f:58config = json.load(f)59
60config = set_config(config)61set_seed(args.seed)62
63main(config)64