vllm

Форк
0
/
test_utils.py 
172 строки · 4.2 Кб
1
from collections import OrderedDict
2

3
from torch import nn
4

5
from vllm.utils import LRUCache
6
from vllm.lora.utils import (parse_fine_tuned_lora_name, replace_submodule)
7

8

9
def test_parse_fine_tuned_lora_name():
10
    fixture = {
11
        ("base_model.model.lm_head.lora_A.weight", "lm_head", True),
12
        ("base_model.model.lm_head.lora_B.weight", "lm_head", False),
13
        (
14
            "base_model.model.model.embed_tokens.lora_embedding_A",
15
            "model.embed_tokens",
16
            True,
17
        ),
18
        (
19
            "base_model.model.model.embed_tokens.lora_embedding_B",
20
            "model.embed_tokens",
21
            False,
22
        ),
23
        (
24
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
25
            "model.layers.9.mlp.down_proj",
26
            True,
27
        ),
28
        (
29
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
30
            "model.layers.9.mlp.down_proj",
31
            False,
32
        ),
33
    }
34
    for name, module_name, is_lora_a in fixture:
35
        assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
36

37

38
def test_replace_submodule():
39
    model = nn.Sequential(
40
        OrderedDict([
41
            ("dense1", nn.Linear(764, 100)),
42
            ("act1", nn.ReLU()),
43
            ("dense2", nn.Linear(100, 50)),
44
            (
45
                "seq1",
46
                nn.Sequential(
47
                    OrderedDict([
48
                        ("dense1", nn.Linear(100, 10)),
49
                        ("dense2", nn.Linear(10, 50)),
50
                    ])),
51
            ),
52
            ("act2", nn.ReLU()),
53
            ("output", nn.Linear(50, 10)),
54
            ("outact", nn.Sigmoid()),
55
        ]))
56

57
    sigmoid = nn.Sigmoid()
58

59
    replace_submodule(model, "act1", sigmoid)
60
    assert dict(model.named_modules())["act1"] == sigmoid
61

62
    dense2 = nn.Linear(1, 5)
63
    replace_submodule(model, "seq1.dense2", dense2)
64
    assert dict(model.named_modules())["seq1.dense2"] == dense2
65

66

67
class TestLRUCache(LRUCache):
68

69
    def _on_remove(self, key, value):
70
        if not hasattr(self, "_remove_counter"):
71
            self._remove_counter = 0
72
        self._remove_counter += 1
73

74

75
def test_lru_cache():
76
    cache = TestLRUCache(3)
77

78
    cache.put(1, 1)
79
    assert len(cache) == 1
80

81
    cache.put(1, 1)
82
    assert len(cache) == 1
83

84
    cache.put(2, 2)
85
    assert len(cache) == 2
86

87
    cache.put(3, 3)
88
    assert len(cache) == 3
89
    assert set(cache.cache) == {1, 2, 3}
90

91
    cache.put(4, 4)
92
    assert len(cache) == 3
93
    assert set(cache.cache) == {2, 3, 4}
94
    assert cache._remove_counter == 1
95
    assert cache.get(2) == 2
96

97
    cache.put(5, 5)
98
    assert set(cache.cache) == {2, 4, 5}
99
    assert cache._remove_counter == 2
100

101
    assert cache.pop(5) == 5
102
    assert len(cache) == 2
103
    assert set(cache.cache) == {2, 4}
104
    assert cache._remove_counter == 3
105

106
    cache.pop(10)
107
    assert len(cache) == 2
108
    assert set(cache.cache) == {2, 4}
109
    assert cache._remove_counter == 3
110

111
    cache.get(10)
112
    assert len(cache) == 2
113
    assert set(cache.cache) == {2, 4}
114
    assert cache._remove_counter == 3
115

116
    cache.put(6, 6)
117
    assert len(cache) == 3
118
    assert set(cache.cache) == {2, 4, 6}
119
    assert 2 in cache
120
    assert 4 in cache
121
    assert 6 in cache
122

123
    cache.remove_oldest()
124
    assert len(cache) == 2
125
    assert set(cache.cache) == {2, 6}
126
    assert cache._remove_counter == 4
127

128
    cache.clear()
129
    assert len(cache) == 0
130
    assert cache._remove_counter == 6
131

132
    cache._remove_counter = 0
133

134
    cache[1] = 1
135
    assert len(cache) == 1
136

137
    cache[1] = 1
138
    assert len(cache) == 1
139

140
    cache[2] = 2
141
    assert len(cache) == 2
142

143
    cache[3] = 3
144
    assert len(cache) == 3
145
    assert set(cache.cache) == {1, 2, 3}
146

147
    cache[4] = 4
148
    assert len(cache) == 3
149
    assert set(cache.cache) == {2, 3, 4}
150
    assert cache._remove_counter == 1
151
    assert cache[2] == 2
152

153
    cache[5] = 5
154
    assert set(cache.cache) == {2, 4, 5}
155
    assert cache._remove_counter == 2
156

157
    del cache[5]
158
    assert len(cache) == 2
159
    assert set(cache.cache) == {2, 4}
160
    assert cache._remove_counter == 3
161

162
    cache.pop(10)
163
    assert len(cache) == 2
164
    assert set(cache.cache) == {2, 4}
165
    assert cache._remove_counter == 3
166

167
    cache[6] = 6
168
    assert len(cache) == 3
169
    assert set(cache.cache) == {2, 4, 6}
170
    assert 2 in cache
171
    assert 4 in cache
172
    assert 6 in cache
173

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.