# Source: tensor-sensor / test_parser.py (176 lines, 5.7 KB; forks: 0)
"""
MIT License

Copyright (c) 2021 Terence Parr

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
24
from tsensor.parsing import *
25
import re
26

27
def check(s, expected_repr, expect_str=None):
    """Parse ``s`` and assert both its round-tripped text and its tree shape.

    All whitespace is removed before every comparison, so expected strings
    may be laid out over several lines for readability.

    :param s: source text to parse with ``PyExprParser``.
    :param expected_repr: expected ``repr()`` of the resulting parse tree.
    :param expect_str: expected ``str()`` of the parse tree when it should
        differ from ``s``; when ``None`` (the default), ``s`` itself is the
        expected text.
    :raises AssertionError: if either the str or the repr comparison fails.
    """
    def strip_ws(text):
        # Drop every whitespace run so formatting differences are ignored.
        return re.sub(r"\s+", "", text)

    # NOTE(review): hush_errors=False presumably makes parse errors raise
    # instead of being suppressed — confirm against PyExprParser.
    tree = PyExprParser(s, hush_errors=False).parse()

    # Choose the expected text first, then normalize it, so an expect_str
    # containing whitespace is compared on the same footing as the result.
    expected_str = strip_ws(s if expect_str is None else expect_str)
    result_str = strip_ws(str(tree))
    assert result_str == expected_str, \
        f"str mismatch: got {result_str!r}, expected {expected_str!r}"

    result_repr = strip_ws(repr(tree))
    expected_repr = strip_ws(expected_repr)
    assert result_repr == expected_repr, \
        f"repr mismatch: got {result_repr!r}, expected {expected_repr!r}"
44

45

46
def test_assign():
    """``a = 3`` parses to an Assign whose ``=`` token sits at offsets 2:3."""
    check("a = 3", "Assign(op=<EQUAL:=,2:3>,lhs=a,rhs=3)")
48

49

50
def test_index():
    """A subscript mixing a bare slice and names parses to one Index node."""
    check("a[:,i,j]", "Index(arr=a, index=[:, i, j])")
52

53

54
def test_index2():
    """An Index node can be the right-hand side of an assignment."""
    check("z = a[:]", "Assign(op=<EQUAL:=,2:3>,lhs=z,rhs=Index(arr=a,index=[:]))")
56

57
def test_index3():
    """Subscripting a member access indexes the Member node, not the bare name."""
    check("g.W[:,:,1]", "Index(arr=Member(op=<DOT:.,1:2>,obj=g,member=W),index=[:,:,1])")
59

60
def test_literal_list():
    """Nested list literals parse to nested ListLiteral nodes."""
    check("[[1, 2], [3, 4]]",
          "ListLiteral(elems=[ListLiteral(elems=[1, 2]), ListLiteral(elems=[3, 4])])")
63

64

65
def test_literal_array():
    """A list literal passed to a dotted call becomes that Call's only argument."""
    check("np.array([[1, 2], [3, 4]])",
          """
          Call(func=Member(op=<DOT:.,2:3>,obj=np,member=array),
               args=[ListLiteral(elems=[ListLiteral(elems=[1,2]),ListLiteral(elems=[3,4])])])
          """)
71

72

73
def test_method():
    """A dotted call on the right of ``=`` parses as Assign(rhs=Call(func=Member))."""
    check("h = torch.tanh(h)",
          "Assign(op=<EQUAL:=,2:3>,lhs=h,rhs=Call(func=Member(op=<DOT:.,9:10>,obj=torch,member=tanh),args=[h]))")
76

77

78
def test_method2():
    """A dotted call with two positional arguments keeps both, in order, in ``args``."""
    check("np.dot(b,b)",
          "Call(func=Member(op=<DOT:.,2:3>,obj=np,member=dot),args=[b,b])")
81

82

83
def test_method3():
    """Calling a plain name (no dot) yields a Call whose ``func`` is that name."""
    check("y_pred = model(X)",
          "Assign(op=<EQUAL:=,7:8>,lhs=y_pred,rhs=Call(func=model,args=[X]))")
86

87

88
def test_field():
    """A single attribute access parses to a Member node."""
    check("a.b", "Member(op=<DOT:.,1:2>,obj=a,member=b)")
90

91

92
def test_member_func():
    """A zero-argument method call produces an empty ``args`` list."""
    check("a.f()", "Call(func=Member(op=<DOT:.,1:2>,obj=a,member=f),args=[])")
94

95

96
def test_field2():
    """Chained attribute access nests left to right: ``(a.b).c``."""
    check("a.b.c", "Member(op=<DOT:.,3:4>,obj=Member(op=<DOT:.,1:2>,obj=a,member=b),member=c)")
98

99

100
def test_field_and_func():
    """Attribute access on a call result wraps the Call inside a Member."""
    check("a.f().c", "Member(op=<DOT:.,5:6>,obj=Call(func=Member(op=<DOT:.,1:2>,obj=a,member=f),args=[]),member=c)")
102

103

104
def test_parens():
    """Parentheses override precedence: the ``a+b`` group is the lhs of ``*``."""
    check("(a+b)*c", "BinaryOp(op=<STAR:*,5:6>,lhs=BinaryOp(op=<PLUS:+,2:3>,lhs=a,rhs=b),rhs=c)")
106

107

108
def test_expr_in_arg_with_parens():
    """A compound arithmetic argument is parsed as one arg with ``*`` over ``+``."""
    check("h = torch.tanh( (1-z)*h + z*h_ )",
          "Assign(op=<EQUAL:=,2:3>,lhs=h,rhs=Call(func=Member(op=<DOT:.,9:10>,obj=torch,member=tanh),args=[BinaryOp(op=<PLUS:+,24:25>,lhs=BinaryOp(op=<STAR:*,21:22>,lhs=BinaryOp(op=<MINUS:-,18:19>,lhs=1,rhs=z),rhs=h),rhs=BinaryOp(op=<STAR:*,27:28>,lhs=z,rhs=h_))]))")
111

112

113
def test_1tuple():
    """A trailing comma turns a single parenthesized value into a 1-tuple."""
    check("(3,)", "TupleLiteral(elems=[3])")
115

116

117
def test_2tuple():
    """A parenthesized, comma-separated pair parses to a TupleLiteral."""
    check("(3,4)", "TupleLiteral(elems=[3,4])")
119

120

121
def test_2tuple_with_trailing_comma():
    """A trailing comma after the last element does not add an extra element."""
    # NOTE(review): check() compares str(tree) against "(3,4,)" verbatim; if
    # TupleLiteral's str() drops the trailing comma this needs
    # expect_str="(3,4)" — confirm against tsensor.parsing.
    check("(3,4,)", "TupleLiteral(elems=[3,4])")
123

124

125
def test_field_array():
    """Subscripting an attribute yields Index(arr=Member(...))."""
    check("a.b[34]", "Index(arr=Member(op=<DOT:.,1:2>,obj=a,member=b),index=[34])")
127

128

129
def test_field_array_func():
    """A member/index/call postfix chain nests strictly from left to right."""
    check("a.b[34].f()", "Call(func=Member(op=<DOT:.,7:8>,obj=Index(arr=Member(op=<DOT:.,1:2>,obj=a,member=b),index=[34]),member=f),args=[])")
131

132

133
def test_arith():
    """``*`` binds tighter than ``+``; the parenthesized ``1-z`` is an operand of ``*``."""
    check("(1-z)*h + z*h_",
          """BinaryOp(op=<PLUS:+,8:9>,
                      lhs=BinaryOp(op=<STAR:*,5:6>,
                                 lhs=BinaryOp(op=<MINUS:-,2:3>,
                                              lhs=1,
                                              rhs=z),
                                 rhs=h),
                      rhs=BinaryOp(op=<STAR:*,11:12>,lhs=z,rhs=h_))""")
142

143

144
def test_pow():
    """``**`` is one two-character token, spanning offsets 1:3."""
    check("a**2",
          """BinaryOp(op=<DOUBLESTAR:**,1:3>,lhs=a,rhs=2)""")
147

148

149
def test_chained_pow():
    """``**`` is right-associative: ``a**b**c`` parses as ``a**(b**c)``."""
    check("a**b**c",
          """BinaryOp(op=<DOUBLESTAR:**,1:3>,lhs=a,rhs=BinaryOp(op=<DOUBLESTAR:**,4:6>,lhs=b,rhs=c))""")
152

153

154
def test_chained_op():
    """``+`` is left-associative: ``a + b + c`` parses as ``(a + b) + c``."""
    check("a + b + c",
          """BinaryOp(op=<PLUS:+,6:7>,
                      lhs=BinaryOp(op=<PLUS:+,2:3>, lhs=a, rhs=b),
                      rhs=c)""")
159

160

161
def test_matrix_arith():
    """``@`` binds tighter than ``+``, and member access binds tighter than ``@``."""
    check("self.Whz@h + Uxz@x + bz",
          """
          BinaryOp(op=<PLUS:+,19:20>,
                   lhs=BinaryOp(op=<PLUS:+,11:12>,
                                lhs=BinaryOp(op=<AT:@,8:9>,lhs=Member(op=<DOT:.,4:5>,obj=self,member=Whz),rhs=h),
                                rhs=BinaryOp(op=<AT:@,16:17>,lhs=Uxz,rhs=x)),
                   rhs=bz)
          """)
170

171
def test_kwarg():
    """A keyword argument is represented as an Assign node inside the call's ``args``."""
    check("torch.relu(torch.rand(size=(2000,)))",
          """
          Call(func=Member(op=<DOT:.,5:6>,obj=torch,member=relu),
               args=[Call(func=Member(op=<DOT:.,16:17>,obj=torch,member=rand),
                          args=[Assign(op=<EQUAL:=,26:27>,lhs=size,rhs=TupleLiteral(elems=[2000]))])])""")

# [Web-page footer] Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal
# data in order to improve our website and the GitVerse service and to make
# them more convenient to use.
#
# You can disable cookies yourself in your browser settings.