colossalai

Форк
0
/
test_init_logical_pg.py 
37 строк · 1.0 Кб
1
import pytest
2
import torch
3
import torch.distributed as dist
4
from torch.distributed import ReduceOp
5

6
from colossalai.device.device_mesh import DeviceMesh
7
from colossalai.initialize import launch
8
from colossalai.testing import rerun_if_address_is_in_use, spawn
9

10

11
def check_layer(rank, world_size, port):
12
    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
13

14
    physical_mesh_id = torch.arange(0, 4)
15
    assert rank == dist.get_rank()
16

17
    tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda()
18
    mesh_shape = (2, 2)
19
    # [[0, 1,
20
    #  [2, 3]]
21
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
22

23
    for axis in range(len(mesh_shape)):
24
        tensor = torch.ones(4).cuda()
25
        pg = device_mesh.get_process_group(axis=axis)
26
        dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
27
        assert tensor.equal(tensor_to_check)
28

29

30
@pytest.mark.dist
31
@rerun_if_address_is_in_use()
32
def test_logical_pg():
33
    spawn(check_layer, 4)
34

35

36
if __name__ == "__main__":
37
    test_logical_pg()
38

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.