#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
from checks_2d.check_layer_2d import (
    check_classifier_given_embed_weight,
    check_classifier_no_given_weight,
    check_embed,
    check_layernorm,
    check_linear,
    check_loss,
    check_patch_embed,
    check_vocab_parallel_classifier_given_embed_weight,
    check_vocab_parallel_classifier_no_given_weight,
    check_vocab_parallel_embed,
    check_vocab_parallel_loss,
)
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB

from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn

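# No pipeline parallelism; 4-way tensor parallelism in "2d" mode, i.e. a 2x2 device mesh.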
CONFIG = dict(
    parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode="2d")),
)


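# Exercise the distributed matmul primitives behind 2D tensor parallelism:
# A @ B, A @ B^T, and A^T @ B.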
def check_operations():
    check_AB()
    check_ABT()
    check_ATB()


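# Layer-level checks: each check_* helper runs a 2D-parallel layer or loss
# and validates its forward/backward results.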
def check_layer():
    check_linear()
    check_layernorm()
    check_embed()
    check_patch_embed()
    check_vocab_parallel_embed()
    check_classifier_no_given_weight()
    check_vocab_parallel_classifier_no_given_weight()
    check_classifier_given_embed_weight()
    check_vocab_parallel_classifier_given_embed_weight()
    check_loss()
    check_vocab_parallel_loss()


def check_layer_and_operation(rank, world_size, port):
    disable_existing_loggers()
    launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    # Disable TF32 and force deterministic cuDNN kernels so results are
    # reproducible and can be compared exactly across runs.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cudnn.deterministic = True
    # check_operations()
    check_layer()
    gpc.destroy()
    torch.cuda.empty_cache()


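# Distributed test entry point: spawns 4 worker processes for the 4-way
# 2D tensor-parallel configuration above.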
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_2d():
    spawn(check_layer_and_operation, 4)


if __name__ == "__main__":
    test_2d()