OpenAttack

Форк
0
/
demo_deo.py 
59 строк · 2.1 Кб
1
import OpenAttack
2
import nltk
3
from nltk.sentiment.vader import SentimentIntensityAnalyzer
4
import numpy as np
5
import datasets
6
import transformers
7

8
def make_model():
9
    class MyClassifier(OpenAttack.Classifier):
10
        def __init__(self):
11
            try:
12
                self.model = SentimentIntensityAnalyzer()
13
            except LookupError:
14
                nltk.download('vader_lexicon')
15
                self.model = SentimentIntensityAnalyzer()
16
        
17
        def get_pred(self, input_):
18
            return self.get_prob(input_).argmax(axis=1)
19

20
        def get_prob(self, input_):
21
            ret = []
22
            for sent in input_:
23
                res = self.model.polarity_scores(sent)
24
                prob = (res["pos"] + 1e-6) / (res["neg"] + res["pos"] + 1e-6)
25
                ret.append(np.array([1 - prob, prob]))
26
            return np.array(ret)
27
    return MyClassifier()
28

29
def dataset_mapping(x):
30
    return {
31
        "x": x["sentence"],
32
        "y": 1 if x["label"] > 0.5 else 0,
33
    }
34

35
def main():
36
    print("New Attacker")
37
    #attacker = OpenAttack.attackers.PWWSAttacker()
38
    print("Build model")
39
    #clsf = OpenAttack.loadVictim("BERT.SST")
40
    clsf = OpenAttack.DataManager.loadVictim("BERT.SST")
41
    #tokenizer = transformers.AutoTokenizer.from_pretrained("./data/Victim.BERT.SST")
42
    #model = transformers.AutoModelForSequenceClassification.from_pretrained("./data/Victim.BERT.SST", num_labels=2, output_hidden_states=True)
43
    #clsf = OpenAttack.classifiers.TransformersClassifier(model, tokenizer=tokenizer, max_length=100, embedding_layer=model.bert.embeddings.word_embeddings)
44

45
    dataset = datasets.load_dataset("sst", split="train[:100]").map(function=dataset_mapping)
46
    print("New Attacker")
47
    attacker = OpenAttack.attackers.UATAttacker()
48
    attacker.set_triggers(clsf, dataset)
49
    print("Start attack")
50
    attack_eval = OpenAttack.AttackEval( attacker, clsf, metrics=[
51
        OpenAttack.metric.Fluency(),
52
        OpenAttack.metric.GrammaticalErrors(),
53
        OpenAttack.metric.EditDistance(),
54
        OpenAttack.metric.ModificationRate()
55
    ] )
56
    attack_eval.eval(dataset, visualize=True, progress_bar=True)
57

58
if __name__ == "__main__":
59
    main()
60

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.