# pytorch
# 91 lines · 2.8 KB
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from model_registry import ModelWithKwargs

import torch
from torch.distributed.pipelining import pipeline
from torch.distributed.pipelining.microbatch import (
    merge_chunks,
    split_args_kwargs_into_chunks,
    TensorChunkSpec,
)
from torch.testing._internal.common_utils import run_tests, TestCase
13
14
# Width (second dimension) of every random tensor the tests below create.
d_hid = 512
# Seed the global RNG at import time so the torch.randn inputs are reproducible.
torch.manual_seed(0)
17
18
class MicrobatchTests(TestCase):
    """Tests for the microbatch split/merge helpers of torch.distributed.pipelining."""

    def test_split_and_merge(self):
        """Split positional and keyword tensors into two chunks, then merge them back.

        Verifies the number of chunks, the per-chunk shapes (each input is
        halved along dim 0), and that ``merge_chunks`` reproduces the
        original ``args`` / ``kwargs``.
        """
        x0 = torch.randn(128, d_hid)
        x1 = torch.randn(256, d_hid)
        x2 = torch.randn(512, d_hid)

        args = (x0, x1, x2)
        kwargs = {"x0": x0, "x1": x1, "x2": x2}

        # Default chunking: dim 0
        arg_chunks, kwarg_chunks = split_args_kwargs_into_chunks(args, kwargs, 2)

        # Use TestCase assertions rather than bare ``assert``: they are not
        # stripped under ``python -O`` and they report both operands on failure.
        self.assertEqual(len(arg_chunks), 2)
        self.assertEqual(len(kwarg_chunks), 2)
        self.assertEqual(arg_chunks[0][0].shape, torch.Size([64, d_hid]))
        self.assertEqual(arg_chunks[1][0].shape, torch.Size([64, d_hid]))
        self.assertEqual(arg_chunks[0][1].shape, torch.Size([128, d_hid]))
        self.assertEqual(arg_chunks[0][2].shape, torch.Size([256, d_hid]))
        self.assertEqual(kwarg_chunks[0]["x0"].shape, torch.Size([64, d_hid]))
        self.assertEqual(kwarg_chunks[0]["x1"].shape, torch.Size([128, d_hid]))
        self.assertEqual(kwarg_chunks[1]["x2"].shape, torch.Size([256, d_hid]))

        # Merge chunks back together
        merged_args = merge_chunks(
            arg_chunks,
            (TensorChunkSpec(0), TensorChunkSpec(0), TensorChunkSpec(0)),
        )
        torch.testing.assert_close(merged_args, args)

        merged_kwargs = merge_chunks(
            kwarg_chunks,
            {
                "x0": TensorChunkSpec(0),
                "x1": TensorChunkSpec(0),
                "x2": TensorChunkSpec(0),
            },
        )
        torch.testing.assert_close(merged_kwargs, kwargs)
        print("Microbatch test passed")

    def test_chunk_spec(self):
        """Split inputs with explicit chunk specs and check pipe/eager equivalence.

        Chunk specs are built with ``TensorChunkSpec.from_tuple`` and
        ``TensorChunkSpec.from_dict``; the first resulting microbatch is used
        as the example input for ``pipeline``. The pipe is then run on the
        full ``x`` / ``y`` and its first output is compared with the eager
        module's result.
        """
        mod = ModelWithKwargs()
        batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE

        x = torch.randn(batch_size, d_hid)
        y = torch.randn(batch_size, d_hid)

        num_chunks = 4

        # Chunk the positional arg and the "y" kwarg along dim 0.
        args_chunk_spec = TensorChunkSpec.from_tuple((0,))
        kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0})

        args_split, kwargs_split = split_args_kwargs_into_chunks(
            (x,),
            {"y": y},
            num_chunks,
            args_chunk_spec,
            kwargs_chunk_spec,
        )

        # Only the first microbatch is handed to pipeline() as example input.
        pipe = pipeline(
            mod,
            mb_args=args_split[0],
            mb_kwargs=kwargs_split[0],
        )

        ref = mod(x, y)
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
88
89
# Script entry point: when this file is executed directly, hand control to
# the run_tests helper imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
92