OpenAttack

Форк
0
/
test_meta_classifier.py 
72 строки · 2.6 Кб
1
import OpenAttack
2
import numpy as np
3
import unittest, os
4

5
class MetaClassifier(OpenAttack.Classifier):
6
    def __init__(self):
7
        self.last_meta = None
8
    
9
    def get_pred(self, input_):
10
        return self.get_prob(input_)
11
    
12
    def get_prob(self, input_):
13
        return self.get_grad([input_], [0])[0]
14
    
15
    def get_grad(self, input_, labels):
16
        self.last_meta = self.context.input
17
        return np.array([[1, 2, 3]]), None
18

19
class TestMetaClassifier(unittest.TestCase):
20
    def test_get_pred(self):
21
        clsf = MetaClassifier()
22
        self.assertIsNone(clsf.last_meta)
23
        
24
        with self.assertRaises(TypeError):
25
            clsf.get_pred("I love apples")
26
        with self.assertRaises(TypeError):
27
            clsf.get_pred()
28
        with self.assertRaises(TypeError):
29
            clsf.get_pred(["I love apples"], "b", "c")
30
        self.assertIsNone(clsf.last_meta)
31
        clsf.set_context({}, None)
32
        clsf.get_pred(["I love apples"])
33
        self.assertDictEqual(clsf.last_meta, {})
34
        clsf.set_context({"THIS": "that"}, None)
35
        clsf.get_pred(["I love apples"])
36
        self.assertDictEqual(clsf.last_meta, {"THIS": "that"})
37

38
    def test_get_prob(self):
39
        clsf = MetaClassifier()
40
        self.assertIsNone(clsf.last_meta)
41
        with self.assertRaises(TypeError):
42
            clsf.get_prob("I love apples")
43
        with self.assertRaises(TypeError):
44
            clsf.get_prob()
45
        with self.assertRaises(TypeError):
46
            clsf.get_prob(["I love apples"], "b", "c")
47
        self.assertIsNone(clsf.last_meta)
48
        clsf.set_context({}, None)
49
        clsf.get_prob(["I love apples"])
50
        self.assertDictEqual(clsf.last_meta, {})
51
        clsf.set_context({"THIS": "that"}, None)
52
        clsf.get_prob(["I love apples"])
53
        self.assertDictEqual(clsf.last_meta, {"THIS": "that"})
54

55
    def test_get_grad(self):
56
        clsf = MetaClassifier()
57
        self.assertIsNone(clsf.last_meta)
58
        with self.assertRaises(TypeError):
59
            clsf.get_grad("I love apples")
60
        with self.assertRaises(TypeError):
61
            clsf.get_grad()
62
        with self.assertRaises(TypeError):
63
            clsf.get_grad(["I love apples"])
64
        with self.assertRaises(TypeError):
65
            clsf.get_grad(["I love apples"], "b", "c", "d")
66
        self.assertIsNone(clsf.last_meta)
67
        clsf.set_context({}, None)
68
        clsf.get_grad([["I", "love", "apples"]], [0])
69
        self.assertDictEqual(clsf.last_meta, {})
70
        clsf.set_context({"THIS": "that"}, None)
71
        clsf.get_grad([["I", "love", "apples"]], [0])
72
        self.assertDictEqual(clsf.last_meta, {"THIS": "that"})

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.