colossalai

Форк
0
/
test_coloproxy.py 
48 строк · 1.3 Кб
1
import torch
2
import torch.nn as nn
3
from torch.fx import GraphModule
4

5
from colossalai.fx.proxy import ColoProxy
6
from colossalai.fx.tracer.tracer import ColoTracer
7
from colossalai.testing import clear_cache_before_run
8

9

10
class Conv1D(nn.Module):
11
    def __init__(self, nf, nx):
12
        super().__init__()
13
        self.nf = nf
14
        w = torch.empty(nx, nf)
15
        nn.init.normal_(w, std=0.02)
16
        self.weight = nn.Parameter(w)
17
        self.bias = nn.Parameter(torch.zeros(nf))
18

19
    def forward(self, x):
20
        size_out = x.shape[:-1] + (self.nf,)
21
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
22
        x = x.view(size_out)
23
        return x
24

25

26
@clear_cache_before_run()
27
def test_coloproxy():
28
    tracer = ColoTracer()
29
    model = Conv1D(3, 3)
30
    input_sample = {"x": torch.rand(3, 3).to("meta")}
31

32
    graph = tracer.trace(root=model, meta_args=input_sample)
33
    gm = GraphModule(model, graph, model.__class__.__name__)
34
    gm.recompile()
35
    node = list(gm.graph.nodes)[0]
36

37
    proxy = ColoProxy(node=node, tracer=tracer)
38
    proxy.meta_data = torch.empty(4, 2, device="meta")
39

40
    assert len(proxy) == 4
41
    assert proxy.shape[0] == 4 and proxy.shape[1] == 2
42
    assert proxy.dim() == 2
43
    assert proxy.dtype == torch.float32
44
    assert proxy.size(0) == 4
45

46

47
if __name__ == "__main__":
48
    test_coloproxy()
49

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.