1
from collections import OrderedDict
5
from vllm.utils import LRUCache
6
from vllm.lora.utils import (parse_fine_tuned_lora_name, replace_submodule)
9
def test_parse_fine_tuned_lora_name():
11
("base_model.model.lm_head.lora_A.weight", "lm_head", True),
12
("base_model.model.lm_head.lora_B.weight", "lm_head", False),
14
"base_model.model.model.embed_tokens.lora_embedding_A",
19
"base_model.model.model.embed_tokens.lora_embedding_B",
24
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
25
"model.layers.9.mlp.down_proj",
29
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
30
"model.layers.9.mlp.down_proj",
34
for name, module_name, is_lora_a in fixture:
35
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
38
def test_replace_submodule():
39
model = nn.Sequential(
41
("dense1", nn.Linear(764, 100)),
43
("dense2", nn.Linear(100, 50)),
48
("dense1", nn.Linear(100, 10)),
49
("dense2", nn.Linear(10, 50)),
53
("output", nn.Linear(50, 10)),
54
("outact", nn.Sigmoid()),
57
sigmoid = nn.Sigmoid()
59
replace_submodule(model, "act1", sigmoid)
60
assert dict(model.named_modules())["act1"] == sigmoid
62
dense2 = nn.Linear(1, 5)
63
replace_submodule(model, "seq1.dense2", dense2)
64
assert dict(model.named_modules())["seq1.dense2"] == dense2
67
class TestLRUCache(LRUCache):
69
def _on_remove(self, key, value):
70
if not hasattr(self, "_remove_counter"):
71
self._remove_counter = 0
72
self._remove_counter += 1
76
cache = TestLRUCache(3)
79
assert len(cache) == 1
82
assert len(cache) == 1
85
assert len(cache) == 2
88
assert len(cache) == 3
89
assert set(cache.cache) == {1, 2, 3}
92
assert len(cache) == 3
93
assert set(cache.cache) == {2, 3, 4}
94
assert cache._remove_counter == 1
95
assert cache.get(2) == 2
98
assert set(cache.cache) == {2, 4, 5}
99
assert cache._remove_counter == 2
101
assert cache.pop(5) == 5
102
assert len(cache) == 2
103
assert set(cache.cache) == {2, 4}
104
assert cache._remove_counter == 3
107
assert len(cache) == 2
108
assert set(cache.cache) == {2, 4}
109
assert cache._remove_counter == 3
112
assert len(cache) == 2
113
assert set(cache.cache) == {2, 4}
114
assert cache._remove_counter == 3
117
assert len(cache) == 3
118
assert set(cache.cache) == {2, 4, 6}
123
cache.remove_oldest()
124
assert len(cache) == 2
125
assert set(cache.cache) == {2, 6}
126
assert cache._remove_counter == 4
129
assert len(cache) == 0
130
assert cache._remove_counter == 6
132
cache._remove_counter = 0
135
assert len(cache) == 1
138
assert len(cache) == 1
141
assert len(cache) == 2
144
assert len(cache) == 3
145
assert set(cache.cache) == {1, 2, 3}
148
assert len(cache) == 3
149
assert set(cache.cache) == {2, 3, 4}
150
assert cache._remove_counter == 1
154
assert set(cache.cache) == {2, 4, 5}
155
assert cache._remove_counter == 2
158
assert len(cache) == 2
159
assert set(cache.cache) == {2, 4}
160
assert cache._remove_counter == 3
163
assert len(cache) == 2
164
assert set(cache.cache) == {2, 4}
165
assert cache._remove_counter == 3
168
assert len(cache) == 3
169
assert set(cache.cache) == {2, 4, 6}