colossalai

Форк
0
71 строка · 1.8 Кб
1
#!/usr/bin/env python
2
# -*- encoding: utf-8 -*-
3

4
import pytest
5
import torch
6
from checks_2d.check_layer_2d import (
7
    check_classifier_given_embed_weight,
8
    check_classifier_no_given_weight,
9
    check_embed,
10
    check_layernorm,
11
    check_linear,
12
    check_loss,
13
    check_patch_embed,
14
    check_vocab_parallel_classifier_given_embed_weight,
15
    check_vocab_parallel_classifier_no_given_weight,
16
    check_vocab_parallel_embed,
17
    check_vocab_parallel_loss,
18
)
19
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB
20

21
from colossalai.legacy.core import global_context as gpc
22
from colossalai.legacy.initialize import launch
23
from colossalai.logging import disable_existing_loggers
24
from colossalai.testing import rerun_if_address_is_in_use, spawn
25

26
CONFIG = dict(
27
    parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode="2d")),
28
)
29

30

31
def check_operations():
32
    check_AB()
33
    check_ABT()
34
    check_ATB()
35

36

37
def check_layer():
38
    check_linear()
39
    check_layernorm()
40
    check_embed()
41
    check_patch_embed()
42
    check_vocab_parallel_embed()
43
    check_classifier_no_given_weight()
44
    check_vocab_parallel_classifier_no_given_weight()
45
    check_classifier_given_embed_weight()
46
    check_vocab_parallel_classifier_given_embed_weight()
47
    check_loss()
48
    check_vocab_parallel_loss()
49

50

51
def check_layer_and_operation(rank, world_size, port):
52
    disable_existing_loggers()
53
    launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
54

55
    torch.backends.cuda.matmul.allow_tf32 = False
56
    torch.backends.cudnn.allow_tf32 = False
57
    torch.backends.cudnn.deterministic = True
58
    # check_operations()
59
    check_layer()
60
    gpc.destroy()
61
    torch.cuda.empty_cache()
62

63

64
@pytest.mark.dist
65
@rerun_if_address_is_in_use()
66
def test_2d():
67
    spawn(check_layer_and_operation, 4)
68

69

70
if __name__ == "__main__":
71
    test_2d()
72

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.