colossalai / test_shape_consistency_apply.py

import pytest
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_apply(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    entire_shape = torch.Size((4, 2))
    shape_consistency_manager = ShapeConsistencyManager()
    dim_partition_source = {0: [0]}
    dim_partition_target = {1: [0]}

    # DistSpec:
    #     shard_sequence: S0,R
    #     device_mesh_shape: (2, 2)
    sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)

    # DistSpec:
    #     shard_sequence: R,S0
    #     device_mesh_shape: (2, 2)
    sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
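    # On mesh [[0, 1], [2, 3]], mesh axis 0 separates ranks {0, 1} from {2, 3}:
    # under S0,R ranks 0/1 hold rows 0-1 of the global (4, 2) tensor and ranks
    # 2/3 hold rows 2-3; under R,S0 ranks 0/1 hold column 0 and ranks 2/3 hold
    # column 1, which is what the per-rank tensors below encode.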

    if rank in (0, 1):
        sharded_tensor_0 = torch.zeros(2, 1)
        sharded_tensor_1 = torch.ones(2, 1)
        # tensor([[0., 1.],
        #         [0., 1.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
    if rank in (2, 3):
        sharded_tensor_0 = torch.ones(2, 1) * 2
        sharded_tensor_1 = torch.ones(2, 1) * 3
        # tensor([[2., 3.],
        #         [2., 3.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()

    if rank in (0, 1):
        # tensor([[0.],
        #         [0.],
        #         [2.],
        #         [2.]])
        tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (2, 3):
        # tensor([[1.],
        #         [1.],
        #         [3.],
        #         [3.]])
        tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()

    tensor_to_comm.sharding_spec = sharding_spec_source
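    # apply() converts the local shard from the source layout (S0,R) to the
    # target layout (R,S0); the new shard needs data held by the peer rank
    # along mesh axis 0, so communication happens here, and the returned
    # tensor carries the target sharding_spec (checked below).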
    tensor_to_comm = shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
    assert tensor_to_comm.equal(tensor_to_check)
    assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
    world_size = 4
    spawn(check_apply, world_size)


if __name__ == "__main__":
    test_apply()
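
For reference, here is a minimal single-process sketch (plain PyTorch, no ColossalAI calls and no distributed setup) of how the global (4, 2) tensor is split under the two specs used above. The local_shard helper and its chunking rule are illustrative assumptions that reproduce the hard-coded per-rank tensors in check_apply; they are not the library's implementation.

import torch


def local_shard(global_tensor, mesh_shape, mesh_coord, dim_partition):
    # dim_partition maps tensor dim -> list of mesh dims it is sharded over,
    # mirroring the dicts passed to ShardingSpec above (illustrative helper).
    shard = global_tensor
    for tensor_dim, mesh_dims in dim_partition.items():
        num_chunks, index = 1, 0
        for mesh_dim in mesh_dims:
            index = index * mesh_shape[mesh_dim] + mesh_coord[mesh_dim]
            num_chunks *= mesh_shape[mesh_dim]
        shard = shard.chunk(num_chunks, dim=tensor_dim)[index]
    return shard


# Global tensor implied by the per-rank shards in check_apply.
global_tensor = torch.tensor([[0.0, 1.0], [0.0, 1.0], [2.0, 3.0], [2.0, 3.0]])
mesh_shape = (2, 2)
coords = {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}  # rank -> coordinate on mesh [[0, 1], [2, 3]]
for rank, coord in coords.items():
    source = local_shard(global_tensor, mesh_shape, coord, {0: [0]})  # S0,R
    target = local_shard(global_tensor, mesh_shape, coord, {1: [0]})  # R,S0
    print(rank, source.tolist(), target.tolist())

Running this prints, for each rank, the source shard the test builds by hand and the target shard it checks after shape_consistency_manager.apply.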