1
from typing import Dict, Optional
4
from common import make_picollama, run_and_check_merge
5
from transformers import AutoConfig
7
from mergekit.config import (
11
OutputSliceDefinition,
16
@pytest.fixture(scope="session")
def model_a(tmp_path_factory):
    """Provide the first test model, built once per pytest session.

    A fresh temporary directory named ``model_a`` is obtained from
    ``tmp_path_factory`` and handed to ``make_picollama``; its return value
    is what the tests receive (presumably a model path -- TODO confirm).
    """
    target_dir = tmp_path_factory.mktemp("model_a")
    return make_picollama(target_dir)
21
@pytest.fixture(scope="session")
def model_b(tmp_path_factory):
    """Provide the second test model, built once per pytest session.

    A fresh temporary directory named ``model_b`` is obtained from
    ``tmp_path_factory`` and handed to ``make_picollama``; its return value
    is what the tests receive (presumably a model path -- TODO confirm).
    """
    target_dir = tmp_path_factory.mktemp("model_b")
    return make_picollama(target_dir)
26
@pytest.fixture(scope="session")
def model_c(tmp_path_factory):
    """Provide the third test model, built once per pytest session.

    A fresh temporary directory named ``model_c`` is obtained from
    ``tmp_path_factory`` and handed to ``make_picollama``; its return value
    is what the tests receive (presumably a model path -- TODO confirm).
    """
    target_dir = tmp_path_factory.mktemp("model_c")
    return make_picollama(target_dir)
32
def test_gpt2_copy(self):
33
config = MergeConfiguration(
34
merge_method="passthrough",
35
models=[InputModelDefinition(model="gpt2")],
38
run_and_check_merge(config)
40
def test_gpt2_stack(self):
41
config = MergeConfiguration(
42
merge_method="passthrough",
44
OutputSliceDefinition(
45
sources=[InputSliceDefinition(model="gpt2", layer_range=[0, 12])]
52
def _check_config_layers(p: str):
    """Validate that the merged model saved at ``p`` reports 24 layers.

    Loads the config with ``AutoConfig.from_pretrained`` and asserts on its
    ``n_layer`` attribute.
    """
    # Named distinctly so it no longer shadows the enclosing MergeConfiguration.
    merged_cfg = AutoConfig.from_pretrained(p)
    assert merged_cfg.n_layer == 24
56
run_and_check_merge(config, validate=_check_config_layers)
58
def test_linear_merge(self, model_a, model_b):
    """Run a ``linear`` merge of ``model_a`` and ``model_b`` via ``run_and_check_merge``."""
    run_and_check_merge(
        self.two_model_config(model_a, model_b, merge_method="linear")
    )
62
def test_slerp_merge(self, model_a, model_b):
63
config = self.two_model_config(
64
model_a, model_b, merge_method="slerp", base_model=model_a
66
config.parameters = {"t": 0.35}
67
run_and_check_merge(config)
69
def test_task_arithmetic_merge(self, model_a, model_b, model_c):
70
config = self.two_model_config(
71
model_a, model_b, merge_method="task_arithmetic", base_model=model_c
73
run_and_check_merge(config)
75
def test_ties_merge(self, model_a, model_b, model_c):
76
config = self.two_model_config(
81
params={"density": 0.3},
83
run_and_check_merge(config)
85
def test_dare_ties_merge(self, model_a, model_b, model_c):
86
config = self.two_model_config(
89
merge_method="dare_ties",
91
params={"density": 0.66},
93
run_and_check_merge(config)
100
base_model: Optional[str] = None,
101
params: Optional[Dict[str, ParameterSetting]] = None,
103
config = MergeConfiguration(
104
merge_method=merge_method,
105
base_model=base_model,
107
InputModelDefinition(
109
parameters={"weight": 0.6},
111
InputModelDefinition(
113
parameters={"weight": 0.4},