pytorch

Форк
0
/
test_microbatch.py 
91 строка · 2.8 Кб
1
# Copyright (c) Meta Platforms, Inc. and affiliates
2
# Owner(s): ["oncall: distributed"]
3
from model_registry import ModelWithKwargs
4

5
import torch
6
from torch.distributed.pipelining import pipeline
7
from torch.distributed.pipelining.microbatch import (
8
    merge_chunks,
9
    split_args_kwargs_into_chunks,
10
    TensorChunkSpec,
11
)
12
from torch.testing._internal.common_utils import run_tests, TestCase
13

14

15
d_hid = 512
16
torch.manual_seed(0)
17

18

19
class MicrobatchTests(TestCase):
20
    def test_split_and_merge(self):
21
        x0 = torch.randn(128, d_hid)
22
        x1 = torch.randn(256, d_hid)
23
        x2 = torch.randn(512, d_hid)
24

25
        args = (x0, x1, x2)
26
        kwargs = {"x0": x0, "x1": x1, "x2": x2}
27

28
        # Default chunking: dim 0
29
        arg_chunks, kwarg_chunks = split_args_kwargs_into_chunks(args, kwargs, 2)
30
        assert len(arg_chunks) == 2
31
        assert len(kwarg_chunks) == 2
32
        assert arg_chunks[0][0].shape == torch.Size([64, d_hid])
33
        assert arg_chunks[1][0].shape == torch.Size([64, d_hid])
34
        assert arg_chunks[0][1].shape == torch.Size([128, d_hid])
35
        assert arg_chunks[0][2].shape == torch.Size([256, d_hid])
36
        assert kwarg_chunks[0]["x0"].shape == torch.Size([64, d_hid])
37
        assert kwarg_chunks[0]["x1"].shape == torch.Size([128, d_hid])
38
        assert kwarg_chunks[1]["x2"].shape == torch.Size([256, d_hid])
39

40
        # Merge chunks back together
41
        merged_args = merge_chunks(
42
            arg_chunks,
43
            (TensorChunkSpec(0), TensorChunkSpec(0), TensorChunkSpec(0)),
44
        )
45
        torch.testing.assert_close(merged_args, args)
46

47
        merged_kwargs = merge_chunks(
48
            kwarg_chunks,
49
            {
50
                "x0": TensorChunkSpec(0),
51
                "x1": TensorChunkSpec(0),
52
                "x2": TensorChunkSpec(0),
53
            },
54
        )
55
        torch.testing.assert_close(merged_kwargs, kwargs)
56
        print("Microbatch test passed")
57

58
    def test_chunk_spec(self):
59
        mod = ModelWithKwargs()
60
        batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE
61

62
        x = torch.randn(batch_size, d_hid)
63
        y = torch.randn(batch_size, d_hid)
64

65
        num_chunks = 4
66

67
        args_chunk_spec = TensorChunkSpec.from_tuple((0,))
68
        kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0})
69

70
        args_split, kwargs_split = split_args_kwargs_into_chunks(
71
            (x,),
72
            {"y": y},
73
            num_chunks,
74
            args_chunk_spec,
75
            kwargs_chunk_spec,
76
        )
77

78
        pipe = pipeline(
79
            mod,
80
            mb_args=args_split[0],
81
            mb_kwargs=kwargs_split[0],
82
        )
83

84
        ref = mod(x, y)
85
        out = pipe(x, y)[0]
86
        torch.testing.assert_close(out, ref)
87
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
88

89

90
if __name__ == "__main__":
91
    run_tests()
92

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.