OpenAttack
/
demo.py
55 строк · 1.6 Кб
1import OpenAttack2import nltk3from nltk.sentiment.vader import SentimentIntensityAnalyzer4import numpy as np5import datasets6
def make_model():
    """Build a VADER-based sentiment classifier usable as an OpenAttack victim.

    The returned object subclasses ``OpenAttack.Classifier`` and predicts two
    classes: index 0 = negative, index 1 = positive. Class probabilities are
    derived from the ``"neg"`` and ``"pos"`` fields of NLTK VADER's
    ``polarity_scores``.

    Returns:
        A fresh ``MyClassifier`` instance.
    """
    class MyClassifier(OpenAttack.Classifier):
        def __init__(self):
            # The VADER lexicon is a separate NLTK resource; download it on
            # first use if it is missing, then retry construction.
            try:
                self.model = SentimentIntensityAnalyzer()
            except LookupError:
                nltk.download('vader_lexicon')
                self.model = SentimentIntensityAnalyzer()

        def get_pred(self, input_):
            """Return the most probable class index (0 = neg, 1 = pos) per sentence."""
            return self.get_prob(input_).argmax(axis=1)

        def get_prob(self, input_):
            """Return a float array of shape (len(input_), 2) with rows [P(neg), P(pos)]."""
            ret = []
            for sent in input_:
                res = self.model.polarity_scores(sent)
                # Symmetric additive smoothing avoids division by zero. When
                # VADER finds no polar words (pos == neg == 0), this yields 0.5
                # instead of a spurious certain "positive" (1.0).
                prob = (res["pos"] + 1e-6) / (res["neg"] + res["pos"] + 2e-6)
                ret.append([1 - prob, prob])
            # The reshape keeps the result 2-D even for empty input, so that
            # get_pred's argmax(axis=1) does not fail on a (0,) array.
            return np.array(ret, dtype=float).reshape(-1, 2)

    return MyClassifier()
def dataset_mapping(x):
    """Convert one SST record into the ``{"x": text, "y": label}`` form.

    ``x["sentence"]`` is passed through as the input text. ``x["label"]`` is
    binarised with a strict 0.5 threshold: above 0.5 maps to 1, anything else
    maps to 0.
    """
    text = x["sentence"]
    target = 0
    if x["label"] > 0.5:
        target = 1
    return {
        "x": text,
        "y": target,
    }
def main():
    """Run a PWWS attack against the VADER classifier and report metrics.

    Loads the first 100 training examples of the ``sst`` dataset, converts
    each with ``dataset_mapping``, and evaluates the attack. The evaluation
    uses fluency, grammatical-error, semantic-similarity, edit-distance and
    modification-rate metrics, with visualisation and a progress bar enabled.
    """
    print("New Attacker")
    attacker = OpenAttack.attackers.PWWSAttacker()

    print("Build model")
    victim = make_model()

    sst_subset = datasets.load_dataset("sst", split="train[:100]")
    dataset = sst_subset.map(function=dataset_mapping)

    print("Start attack")
    metrics = [
        OpenAttack.metric.Fluency(),
        OpenAttack.metric.GrammaticalErrors(),
        OpenAttack.metric.SemanticSimilarity(),
        OpenAttack.metric.EditDistance(),
        OpenAttack.metric.ModificationRate(),
    ]
    evaluator = OpenAttack.AttackEval(attacker, victim, metrics=metrics)
    evaluator.eval(dataset, visualize=True, progress_bar=True)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()