# colossalai
# 48 lines · 1.3 KB
import torch
import torch.nn as nn
from torch.fx import GraphModule

from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing import clear_cache_before_run
class Conv1D(nn.Module):
    """Affine layer whose weight is stored as an ``(nx, nf)`` matrix.

    Maps the last dimension of the input from ``nx`` to ``nf`` features and
    keeps every leading dimension. ``test_coloproxy`` uses it as a small
    model to trace.

    Args:
        nf: Number of output features (columns of ``weight``, size of ``bias``).
        nx: Number of input features (rows of ``weight``).
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        """Return ``x @ weight + bias`` applied over the last dimension of ``x``.

        The result has shape ``x.shape[:-1] + (nf,)``.
        """
        # Record the leading dims so the 2-D addmm result can be reshaped back.
        size_out = x.shape[:-1] + (self.nf,)
        # addmm needs 2-D operands: fold all leading dims into a single row axis.
        # NOTE: ``view`` requires x to be viewable as 2-D (e.g. contiguous).
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x
@clear_cache_before_run()
def test_coloproxy():
    """Check that ``ColoProxy`` reports tensor metadata from ``meta_data``.

    A ``Conv1D`` is traced with ``ColoTracer`` on a meta-device input to obtain
    a real graph node. A ``ColoProxy`` is built around the first node of that
    graph, its ``meta_data`` is replaced with a ``(4, 2)`` meta tensor, and
    ``len``/``shape``/``dim``/``dtype``/``size`` on the proxy are then expected
    to reflect that tensor.
    """
    tracer = ColoTracer()
    model = Conv1D(3, 3)
    # Meta-device tensors carry only shape/dtype (no storage), which is all
    # the tracer needs from ``meta_args``.
    input_sample = {"x": torch.rand(3, 3).to("meta")}

    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    # Take the graph's first node; the proxy's metadata is overwritten below.
    node = next(iter(gm.graph.nodes))

    proxy = ColoProxy(node=node, tracer=tracer)
    # A (4, 2) tensor differs from the traced (3, 3) input, so the checks
    # below distinguish the assigned metadata from the traced one.
    proxy.meta_data = torch.empty(4, 2, device="meta")

    assert len(proxy) == 4
    assert proxy.shape[0] == 4 and proxy.shape[1] == 2
    assert proxy.dim() == 2
    assert proxy.dtype == torch.float32
    assert proxy.size(0) == 4
if __name__ == "__main__":
    # Allow running this check directly as a script, without pytest.
    test_coloproxy()