# OpenAttack
# 72 lines · 2.6 KB
import os
import unittest

import numpy as np

import OpenAttack


class MetaClassifier(OpenAttack.Classifier):
    """Stub classifier used to observe the context handed to a victim model.

    All three entry points funnel into ``get_grad``: ``get_pred`` calls
    ``get_prob``, which calls ``get_grad`` on a one-element batch.
    ``get_grad`` copies ``self.context.input`` into ``last_meta`` and returns
    a constant placeholder, so tests can inspect which context a call saw.
    """

    def __init__(self):
        # Stays None until get_grad's body has executed at least once.
        self.last_meta = None

    def get_pred(self, input_):
        """Return the result of ``get_prob(input_)`` unchanged.

        NOTE(review): this stub returns the probability array rather than
        predicted labels; the tests in this file only inspect ``last_meta``.
        """
        probabilities = self.get_prob(input_)
        return probabilities

    def get_prob(self, input_):
        """Call ``get_grad`` on ``[input_]`` with label ``0``; return element 0 of its result."""
        single_batch = [input_]
        placeholder_labels = [0]
        grad_result = self.get_grad(single_batch, placeholder_labels)
        return grad_result[0]

    def get_grad(self, input_, labels):
        """Record ``self.context.input`` in ``last_meta`` and return ``(array, None)``.

        ``input_`` and ``labels`` are ignored by this stub. ``self.context``
        is provided by the OpenAttack base class (presumably populated by
        ``set_context`` — confirm against the library).
        """
        self.last_meta = self.context.input
        constant_prob = np.array([[1, 2, 3]])
        return constant_prob, None


class TestMetaClassifier(unittest.TestCase):
    """Check MetaClassifier's argument validation and context forwarding.

    For each of ``get_pred``, ``get_prob`` and ``get_grad`` the tests verify:

    * malformed calls raise ``TypeError`` without reaching the stub body
      (``last_meta`` stays ``None``);
    * after ``set_context(meta, None)``, a well-formed call leaves
      ``clsf.last_meta`` equal to ``meta``.
    """

    def _check_context_tracking(self, method_name, bad_calls, good_args):
        """Run the shared malformed-call / context-forwarding checks.

        Args:
            method_name: name of the MetaClassifier method under test.
            bad_calls: sequence of positional-argument tuples, each of which
                must make the method raise ``TypeError``.
            good_args: positional-argument tuple for a well-formed call.
        """
        clsf = MetaClassifier()
        self.assertIsNone(clsf.last_meta)
        method = getattr(clsf, method_name)

        for args in bad_calls:
            # subTest reports exactly which argument tuple failed to raise.
            with self.subTest(method=method_name, args=args):
                with self.assertRaises(TypeError):
                    method(*args)
        # None of the rejected calls may have reached get_grad's body.
        self.assertIsNone(clsf.last_meta)

        for expected_meta in ({}, {"THIS": "that"}):
            # Hand set_context a copy so the comparison below is against an
            # untouched expected value, as with the original fresh literals.
            clsf.set_context(dict(expected_meta), None)
            method(*good_args)
            self.assertDictEqual(clsf.last_meta, expected_meta)

    def test_get_pred(self):
        self._check_context_tracking(
            "get_pred",
            bad_calls=[
                ("I love apples",),             # plain str, not a list
                (),                             # missing input
                (["I love apples"], "b", "c"),  # extra positional arguments
            ],
            good_args=(["I love apples"],),
        )

    def test_get_prob(self):
        self._check_context_tracking(
            "get_prob",
            bad_calls=[
                ("I love apples",),             # plain str, not a list
                (),                             # missing input
                (["I love apples"], "b", "c"),  # extra positional arguments
            ],
            good_args=(["I love apples"],),
        )

    def test_get_grad(self):
        self._check_context_tracking(
            "get_grad",
            bad_calls=[
                ("I love apples",),                  # plain str, no labels
                (),                                  # missing input and labels
                (["I love apples"],),                # missing labels
                (["I love apples"], "b", "c", "d"),  # extra positional arguments
            ],
            good_args=([["I", "love", "apples"]], [0]),
        )