colossalai

Форк
0
/
test_backward.py 
62 строки · 1.7 Кб
1
import pytest
2
import timm.models as tmm
3
import torch
4
import torchvision.models as tm
5

6
from colossalai.fx._compatibility import is_compatible_with_meta
7

8
if is_compatible_with_meta():
9
    from colossalai.fx.profiler import MetaTensor
10

11
from colossalai.testing import clear_cache_before_run
12

13
tm_models = [
14
    tm.vgg11,
15
    tm.resnet18,
16
    tm.densenet121,
17
    tm.mobilenet_v3_small,
18
    tm.resnext50_32x4d,
19
    tm.wide_resnet50_2,
20
    tm.regnet_x_16gf,
21
    tm.mnasnet0_5,
22
    tm.efficientnet_b0,
23
]
24

25
tmm_models = [
26
    tmm.resnest.resnest50d,
27
    tmm.beit.beit_base_patch16_224,
28
    tmm.cait.cait_s24_224,
29
    tmm.efficientnet.efficientnetv2_m,
30
    tmm.resmlp_12_224,
31
    tmm.vision_transformer.vit_base_patch16_224,
32
    tmm.deit_base_distilled_patch16_224,
33
    tmm.convnext.convnext_base,
34
    tmm.vgg.vgg11,
35
    tmm.dpn.dpn68,
36
    tmm.densenet.densenet121,
37
    tmm.rexnet.rexnet_100,
38
    tmm.swin_transformer.swin_base_patch4_window7_224,
39
]
40

41

42
@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0")
43
@clear_cache_before_run()
44
def test_torchvision_models():
45
    for m in tm_models:
46
        model = m()
47
        data = torch.rand(100000, 3, 224, 224, device="meta")
48
        model(MetaTensor(data, fake_device=torch.device("cpu"))).sum().backward()
49

50

51
@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0")
52
@clear_cache_before_run()
53
def test_timm_models():
54
    for m in tmm_models:
55
        model = m()
56
        data = torch.rand(100000, 3, 224, 224, device="meta")
57
        model(MetaTensor(data, fake_device=torch.device("cpu"))).sum().backward()
58

59

60
if __name__ == "__main__":
61
    test_torchvision_models()
62
    test_timm_models()
63

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.