# tensor-sensor
# 27 lines · 811.0 bytes
"""Scratch script exercising tensor-sensor's AST visualizer on PyTorch tensors.

Builds a few tensors at module scope, then asks ``tsensor.astviz`` to render
the statement ``b = W@b + (h+3).dot(h) + z``. It passes this module's frame
so the names in the statement can be resolved, and calls ``.view()`` on the
returned object.
"""
import matplotlib.pyplot as plt
import numpy as np
import tsensor
import torch
import sys

# Small tensors for the commented-out ``pyviz`` experiment just below.
# NOTE: all four names are reassigned further down before anything reads
# them, so they only matter if that experiment is re-enabled in place.
W = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
b = torch.tensor([9, 10]).reshape(2, 1)
x = torch.tensor([4, 5], dtype=torch.int32).reshape(2, 1)
h = torch.tensor([1, 2])

# fig, ax = plt.subplots(1,1)
# # view = tsensor.pyviz("b + x", ax=ax, legend=True)
# # view.savefig("/Users/parrt/Desktop/foo.pdf")
# plt.show()

# Larger tensors that the astviz call below actually refers to.
W = torch.rand(size=(2000, 2000), dtype=torch.float64)
b = torch.rand(size=(2000, 1), dtype=torch.float64)
# 1-D integer vector (``.dot`` needs 1-D operands). Passing the builtin
# ``int`` as dtype means torch.int64, so that is spelled out explicitly.
h = torch.zeros(size=(1_000_000,), dtype=torch.int64)
# Default floating dtype; not referenced by the visualized statement.
x = torch.rand(size=(2000, 1))
z = torch.rand(size=(2000, 1), dtype=torch.complex64)

# Passing this module's frame lets tsensor look up W, b, h and z by name.
# Presumably it evaluates sub-expressions there to annotate tensor shapes.
g = tsensor.astviz("b = W@b + (h+3).dot(h) + z", sys._getframe())
g.view()

# with tsensor.explain():
#     b + x
29