tensor-sensor
176 строк · 5.7 Кб
"""
MIT License

Copyright (c) 2021 Terence Parr

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
24from tsensor.parsing import *
25import re
26
def _squash_ws(text):
    """Return *text* with every run of whitespace removed."""
    return re.sub(r"\s+", "", text)


def check(s, expected_repr, expect_str=None):
    """Parse *s* and assert on both the ``str`` and the ``repr`` of the tree.

    Whitespace is ignored in every comparison. All runs of whitespace are
    stripped from the source text, from ``str(tree)``, from ``repr(tree)``
    and from the expected strings before they are compared.

    Parameters
    ----------
    s : str
        Source text handed to ``PyExprParser``. Token offsets that appear in
        *expected_repr* (e.g. ``<EQUAL:=,2:3>``) index into this original,
        unstripped text.
    expected_repr : str
        Expected ``repr`` of the parse tree.
    expect_str : str, optional
        Expected ``str`` of the parse tree. When omitted (``None``),
        ``str(tree)`` must reproduce *s* itself, modulo whitespace.

    Raises
    ------
    AssertionError
        If the ``str`` or the ``repr`` of the parsed tree does not match.
    """
    # NOTE(review): hush_errors=False presumably makes the parser raise on
    # malformed input instead of suppressing errors -- confirm in tsensor.parsing.
    tree = PyExprParser(s, hush_errors=False).parse()

    # Normalise the expected str exactly like the actual one. Previously
    # expect_str was used verbatim, so any whitespace in it guaranteed a mismatch.
    expected_str = _squash_ws(s if expect_str is None else expect_str)
    result_str = _squash_ws(str(tree))
    assert result_str == expected_str, (
        f"str mismatch: got {result_str!r}, expected {expected_str!r}"
    )

    result_repr = _squash_ws(repr(tree))
    expected_repr = _squash_ws(expected_repr)
    assert result_repr == expected_repr, (
        f"repr mismatch: got {result_repr!r}, expected {expected_repr!r}"
    )


def test_assign():
    """A plain assignment parses to Assign with its `=` token at offsets 2:3."""
    source = "a = 3"
    expected = "Assign(op=<EQUAL:=,2:3>,lhs=a,rhs=3)"
    check(source, expected)


def test_index():
    """A multi-dimensional subscript mixing a bare slice and names yields Index."""
    source = "a[:,i,j]"
    expected = "Index(arr=a, index=[:, i, j])"
    check(source, expected)


def test_index2():
    """A full-slice Index appearing as the right-hand side of an Assign."""
    source = "z = a[:]"
    expected = "Assign(op=<EQUAL:=,2:3>,lhs=z,rhs=Index(arr=a,index=[:]))"
    check(source, expected)


def test_index3():
    """Subscripting a member access: the Member node becomes the Index's arr."""
    source = "g.W[:,:,1]"
    expected = "Index(arr=Member(op=<DOT:.,1:2>,obj=g,member=W),index=[:,:,1])"
    check(source, expected)


def test_literal_list():
    """Nested list literals parse to nested ListLiteral nodes."""
    source = "[[1, 2], [3, 4]]"
    expected = (
        "ListLiteral(elems=[ListLiteral(elems=[1, 2]), "
        "ListLiteral(elems=[3, 4])])"
    )
    check(source, expected)


def test_literal_array():
    """A nested list literal passed as the sole argument of a method call."""
    source = "np.array([[1, 2], [3, 4]])"
    expected = (
        "\n"
        "Call(func=Member(op=<DOT:.,2:3>,obj=np,member=array),\n"
        "args=[ListLiteral(elems=[ListLiteral(elems=[1,2]),"
        "ListLiteral(elems=[3,4])])])\n"
    )
    check(source, expected)


def test_method():
    """An attribute call assigned back to a name: Assign wrapping Call(Member)."""
    source = "h = torch.tanh(h)"
    expected = (
        "Assign(op=<EQUAL:=,2:3>,lhs=h,"
        "rhs=Call(func=Member(op=<DOT:.,9:10>,obj=torch,member=tanh),args=[h]))"
    )
    check(source, expected)


def test_method2():
    """A two-argument attribute call keeps its arguments in order."""
    source = "np.dot(b,b)"
    expected = "Call(func=Member(op=<DOT:.,2:3>,obj=np,member=dot),args=[b,b])"
    check(source, expected)


def test_method3():
    """Calling a bare name (no attribute access) on the right of an Assign."""
    source = "y_pred = model(X)"
    expected = "Assign(op=<EQUAL:=,7:8>,lhs=y_pred,rhs=Call(func=model,args=[X]))"
    check(source, expected)


def test_field():
    """A single attribute access parses to one Member node."""
    source = "a.b"
    expected = "Member(op=<DOT:.,1:2>,obj=a,member=b)"
    check(source, expected)


def test_member_func():
    """A zero-argument method call yields Call with an empty args list."""
    source = "a.f()"
    expected = "Call(func=Member(op=<DOT:.,1:2>,obj=a,member=f),args=[])"
    check(source, expected)


def test_field2():
    """Chained attribute access nests left-to-right: (a.b).c."""
    source = "a.b.c"
    expected = (
        "Member(op=<DOT:.,3:4>,"
        "obj=Member(op=<DOT:.,1:2>,obj=a,member=b),member=c)"
    )
    check(source, expected)


def test_field_and_func():
    """Attribute access on the result of a method call: Member over Call."""
    source = "a.f().c"
    expected = (
        "Member(op=<DOT:.,5:6>,"
        "obj=Call(func=Member(op=<DOT:.,1:2>,obj=a,member=f),args=[]),"
        "member=c)"
    )
    check(source, expected)


def test_parens():
    """Parentheses make the addition the left operand of the multiplication."""
    source = "(a+b)*c"
    expected = (
        "BinaryOp(op=<STAR:*,5:6>,"
        "lhs=BinaryOp(op=<PLUS:+,2:3>,lhs=a,rhs=b),rhs=c)"
    )
    check(source, expected)


def test_expr_in_arg_with_parens():
    """A parenthesised arithmetic expression used as a call argument.

    The source contains extra spaces, so the expected token offsets
    (e.g. PLUS at 24:25) reflect those spaces.
    """
    source = "h = torch.tanh( (1-z)*h + z*h_ )"
    expected = (
        "Assign(op=<EQUAL:=,2:3>,lhs=h,"
        "rhs=Call(func=Member(op=<DOT:.,9:10>,obj=torch,member=tanh),"
        "args=[BinaryOp(op=<PLUS:+,24:25>,"
        "lhs=BinaryOp(op=<STAR:*,21:22>,"
        "lhs=BinaryOp(op=<MINUS:-,18:19>,lhs=1,rhs=z),rhs=h),"
        "rhs=BinaryOp(op=<STAR:*,27:28>,lhs=z,rhs=h_))]))"
    )
    check(source, expected)


def test_1tuple():
    """A one-element tuple written with its mandatory trailing comma."""
    source = "(3,)"
    expected = "TupleLiteral(elems=[3])"
    check(source, expected)


def test_2tuple():
    """A two-element tuple literal."""
    source = "(3,4)"
    expected = "TupleLiteral(elems=[3,4])"
    check(source, expected)


def test_2tuple_with_trailing_comma():
    """A trailing comma after the last element does not add an element."""
    source = "(3,4,)"
    expected = "TupleLiteral(elems=[3,4])"
    check(source, expected)


def test_field_array():
    """Indexing an attribute with an integer literal."""
    source = "a.b[34]"
    expected = "Index(arr=Member(op=<DOT:.,1:2>,obj=a,member=b),index=[34])"
    check(source, expected)


def test_field_array_func():
    """A method call on an indexed attribute: Call(Member(Index(Member)))."""
    source = "a.b[34].f()"
    expected = (
        "Call(func=Member(op=<DOT:.,7:8>,"
        "obj=Index(arr=Member(op=<DOT:.,1:2>,obj=a,member=b),index=[34]),"
        "member=f),args=[])"
    )
    check(source, expected)


def test_arith():
    """Mixed +, * and - respect precedence: `*` binds tighter than `+`."""
    source = "(1-z)*h + z*h_"
    expected = (
        "BinaryOp(op=<PLUS:+,8:9>,\n"
        "lhs=BinaryOp(op=<STAR:*,5:6>,\n"
        "lhs=BinaryOp(op=<MINUS:-,2:3>,\n"
        "lhs=1,\n"
        "rhs=z),\n"
        "rhs=h),\n"
        "rhs=BinaryOp(op=<STAR:*,11:12>,lhs=z,rhs=h_))"
    )
    check(source, expected)


def test_pow():
    """The two-character `**` operator spans offsets 1:3."""
    source = "a**2"
    expected = "BinaryOp(op=<DOUBLESTAR:**,1:3>,lhs=a,rhs=2)"
    check(source, expected)


def test_chained_pow():
    """`**` is right-associative: a**b**c nests b**c as the right operand."""
    source = "a**b**c"
    expected = (
        "BinaryOp(op=<DOUBLESTAR:**,1:3>,lhs=a,"
        "rhs=BinaryOp(op=<DOUBLESTAR:**,4:6>,lhs=b,rhs=c))"
    )
    check(source, expected)


def test_chained_op():
    """`+` is left-associative: a + b + c nests a + b as the left operand."""
    source = "a + b + c"
    expected = (
        "BinaryOp(op=<PLUS:+,6:7>,\n"
        "lhs=BinaryOp(op=<PLUS:+,2:3>, lhs=a, rhs=b),\n"
        "rhs=c)"
    )
    check(source, expected)


def test_matrix_arith():
    """Matrix-multiply `@` binds tighter than `+`, and `+` chains to the left."""
    source = "self.Whz@h + Uxz@x + bz"
    expected = (
        "\n"
        "BinaryOp(op=<PLUS:+,19:20>,\n"
        "lhs=BinaryOp(op=<PLUS:+,11:12>,\n"
        "lhs=BinaryOp(op=<AT:@,8:9>,"
        "lhs=Member(op=<DOT:.,4:5>,obj=self,member=Whz),rhs=h),\n"
        "rhs=BinaryOp(op=<AT:@,16:17>,lhs=Uxz,rhs=x)),\n"
        "rhs=bz)\n"
    )
    check(source, expected)


def test_kwarg():
    """A keyword argument appears as an Assign node inside the call's args."""
    source = "torch.relu(torch.rand(size=(2000,)))"
    expected = (
        "\n"
        "Call(func=Member(op=<DOT:.,5:6>,obj=torch,member=relu),\n"
        "args=[Call(func=Member(op=<DOT:.,16:17>,obj=torch,member=rand),\n"
        "args=[Assign(op=<EQUAL:=,26:27>,lhs=size,"
        "rhs=TupleLiteral(elems=[2000]))])])"
    )
    check(source, expected)