# tensor-sensor
# 123 lines · 3.3 KB
# Standard library
import os
import sys
import tempfile

# Third-party
import graphviz
import matplotlib.font_manager as fm
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch

# Project
import tsensor
# from tsensor.viz import pyviz, astviz

# Uncomment to print every font known to matplotlib's font manager:
# print('\n'.join(str(f) for f in fm.fontManager.ttflist))

# Toy RNN state-update step with deliberately mismatched shapes, used to show
# what tsensor.explain() renders for a failing expression.
n = 200  # number of instances
d = 764  # number of instance features
nhidden = 256

Whh = torch.eye(nhidden, nhidden)  # Identity matrix
Uxh = torch.randn(nhidden, d)
bh = torch.zeros(nhidden, 1)
h = torch.randn(nhidden, 1)  # fake previous hidden state h
# r = torch.randn(nhidden, 1)  # fake this computation (shape-correct variant)
r = torch.randn(nhidden, 3)  # fake this computation
X = torch.rand(n, d)  # fake input

# The following code raises an exception: r has 3 columns instead of 1, so
# Whh @ (r*h) is (nhidden, 3) while Uxh @ X.T is (nhidden, n), and the two
# cannot be added.
# The figure is written to the platform temp directory instead of an absolute
# path that only exists on one developer's machine.
with tsensor.explain(savefig=os.path.join(tempfile.gettempdir(), "toomany.png")) as e:
    h = torch.tanh(Whh @ (r * h) + Uxh @ X.T + bh)  # state vector update equation

# Stop here: nothing below this line executes. sys.exit() is used because the
# bare exit() builtin is injected by the site module for interactive sessions.
sys.exit()
35
def foo(use_frame=False):
    """Draw the syntax tree of a sample tensor statement with ``tsensor.astviz``.

    Creates a few random tensors as locals, asks ``tsensor.astviz`` to render the
    statement below, and opens the result with ``g.view()``.

    Args:
        use_frame: when True, pass this function's frame to ``tsensor.astviz``
            (presumably so it can resolve ``W``, ``b`` and ``h`` to real values —
            confirm against the tsensor API). The default False passes ``None``,
            which is what this function always did before.
    """
    # NOTE(review): W holds 2000*2000*10*8 = 320M elements (~1.28 GB at the
    # default float32 dtype) although the statement only uses W[:, :, 0, 0].
    W = torch.rand(size=(2000, 2000, 10, 8))
    b = torch.rand(size=(2000, 1))
    h = torch.rand(size=(1_000_000,))
    # Previously sys._getframe() was computed and then immediately discarded;
    # the choice is now explicit.
    frame = sys._getframe() if use_frame else None
    g = tsensor.astviz(
        "b = W[:,:,0,0]@b + (h+3).dot(h) + torch.abs(torch.tensor(34))",
        frame,
    )
    g.view()


# foo()
51
class Linear:
    """Fully connected layer that computes ``W @ x + b``.

    ``W`` has shape ``(n_neurons, d)`` and is filled by ``torch.randn``;
    ``b`` has shape ``(n_neurons, 1)`` and starts at zero.
    """

    def __init__(self, d, n_neurons):
        weight_shape = (n_neurons, d)
        bias_shape = (n_neurons, 1)
        self.W = torch.randn(*weight_shape)
        self.b = torch.zeros(*bias_shape)

    def __call__(self, x):
        """Return ``W @ x + b``.

        For a 2-D ``x`` the matmul needs ``x`` to have ``d`` rows,
        e.g. ``(d, 1)`` or ``(d, k)``.
        """
        projected = torch.matmul(self.W, x)
        return projected + self.b
58
59
# Sizes for the Linear-layer experiment that follows (currently commented out).
# n and d are rebound to the same values the RNN demo earlier in the file uses.
n, d, n_neurons = 200, 764, 100  # instances, features per instance, neurons in this layer
63# L = Linear(d,n_neurons)
64#
65# import tensorflow as tf
66# X = tf.random.uniform((n,d))
67# with tsensor.clarify(hush_errors=False):
68# Y = L(X)
69
70# g = tsensor.pyviz("Y = L(X)", hush_errors=False)
71# g.show()
72
class GRU:
    """Holder of random tensors used as attribute-access targets in the demos.

    Only attributes are stored here; no recurrent computation is implemented.
    """

    def __init__(self):
        # Draw order W -> b -> h is kept so the global torch RNG advances
        # exactly as before.
        self.W = torch.rand(size=(2, 20, 2000, 10))
        self.b = torch.rand(size=(20, 1))
        self.h = torch.rand(size=(1_000_000,))
        self.a = 3
        # Debug output: full shape of W, then the shape after fixing axis 2.
        for shape in (self.W.shape, self.W[:, :, 1].shape):
            print(shape)

    def get(self):
        """Return the constant integer tensor ``[[1, 2], [3, 4]]``."""
        rows = [[1, 2], [3, 4]]
        return torch.tensor(rows)
85
# Module-level operands for the pyviz statement below.
# W = torch.tensor([[1, 2], [3, 4]])
b = torch.rand(size=(2000, 1))
h = torch.rand(size=(1_000_000, 2))
x = torch.rand(size=(1_000_000, 2))  # not referenced by `code` below
a = 3  # not referenced by `code` below

# foo = torch.rand(size=(2000,))
# torch.relu(foo)

g = GRU()

# with tsensor.clarify():
#     tf.constant([1,2]) @ tf.constant([1,3])

# The statement is not shape-correct, apparently on purpose:
# g.W[0,:,:,1] @ b is (20, 1), which cannot broadcast with torch.zeros(200, 1).
# h is also 2-D here; NOTE(review): confirm .dot on it is meant to fail too.
code = "b = g.W[0,:,:,1]@b+torch.zeros(200,1)+(h+3).dot(h)"
# code = "torch.relu(foo)"
# code = "np.dot(b,b)"
# code = "b.T"
# `g` still names the GRU instance while pyviz runs; presumably pyviz resolves
# g.W from this namespace. Afterwards `g` is rebound to pyviz's result.
g = tsensor.pyviz(code, fontname='Courier New', fontsize=16, dimfontsize=9,
                  char_sep_scale=1.8, hush_errors=False)
plt.tight_layout()
# Write to the platform temp directory rather than the POSIX-only /tmp.
plt.savefig(os.path.join(tempfile.gettempdir(), "t.svg"),
            dpi=200, bbox_inches='tight', pad_inches=0)
109
110# W = torch.tensor([[1, 2], [3, 4]])
111# x = torch.tensor([4, 5]).reshape(2, 1)
112# with tsensor.explain():
113# b = torch.rand(size=(2000,))
114# torch.relu(b)
115
116
117# g = GRU()
118#
119# g1 = tsensor.astviz("b = g.W@b + torch.eye(3,3)")
120# g1.view()
121# g1 = tsensor.pyviz("b = g.W@b")
122# g1.view()
123# g2 = tsensor.astviz("b = g.W@b + g.h.dot(g.h) + torch.abs(torch.tensor(34))")
124