mergekit

Форк
0
/
test_basic_merges.py 
120 строк · 3.4 Кб
1
from typing import Dict, Optional
2

3
import pytest
4
from common import make_picollama, run_and_check_merge
5
from transformers import AutoConfig
6

7
from mergekit.config import (
8
    InputModelDefinition,
9
    InputSliceDefinition,
10
    MergeConfiguration,
11
    OutputSliceDefinition,
12
    ParameterSetting,
13
)
14

15

16
@pytest.fixture(scope="session")
def model_a(tmp_path_factory):
    """Session-scoped fixture: result of ``make_picollama`` for a "model_a" temp dir."""
    target_dir = tmp_path_factory.mktemp("model_a")
    return make_picollama(target_dir)
19

20

21
@pytest.fixture(scope="session")
def model_b(tmp_path_factory):
    """Session-scoped fixture: result of ``make_picollama`` for a "model_b" temp dir."""
    target_dir = tmp_path_factory.mktemp("model_b")
    return make_picollama(target_dir)
24

25

26
@pytest.fixture(scope="session")
def model_c(tmp_path_factory):
    """Session-scoped fixture: result of ``make_picollama`` for a "model_c" temp dir."""
    target_dir = tmp_path_factory.mktemp("model_c")
    return make_picollama(target_dir)
29

30

31
class TestBasicMerges:
    """Smoke tests that build a ``MergeConfiguration`` and run it.

    Every test hands its configuration to ``run_and_check_merge``; the
    two-model tests share ``two_model_config`` to build that configuration.
    """

    def test_gpt2_copy(self):
        """Run a ``passthrough`` merge whose only input model is ``gpt2``."""
        config = MergeConfiguration(
            merge_method="passthrough",
            models=[InputModelDefinition(model="gpt2")],
            dtype="bfloat16",
        )
        run_and_check_merge(config)

    def test_gpt2_stack(self):
        """Stack the ``gpt2`` layer range [0, 12] twice and expect 24 layers."""
        config = MergeConfiguration(
            merge_method="passthrough",
            slices=[
                OutputSliceDefinition(
                    sources=[InputSliceDefinition(model="gpt2", layer_range=[0, 12])]
                )
            ]
            * 2,
            dtype="bfloat16",
        )

        def _check_config_layers(p: str):
            # Use a distinct name so the merge configuration built above is
            # not shadowed inside the validator.
            merged_config = AutoConfig.from_pretrained(p)
            assert merged_config.n_layer == 24

        run_and_check_merge(config, validate=_check_config_layers)

    def test_linear_merge(self, model_a, model_b):
        """Run a ``linear`` merge of the two fixture models."""
        config = self.two_model_config(model_a, model_b, merge_method="linear")
        run_and_check_merge(config)

    def test_slerp_merge(self, model_a, model_b):
        """Run a ``slerp`` merge with ``model_a`` as base and ``t`` = 0.35."""
        # Pass ``t`` through the helper's ``params`` argument, as the other
        # parameterised tests do, instead of mutating the configuration
        # after it has been constructed.
        config = self.two_model_config(
            model_a,
            model_b,
            merge_method="slerp",
            base_model=model_a,
            params={"t": 0.35},
        )
        run_and_check_merge(config)

    def test_task_arithmetic_merge(self, model_a, model_b, model_c):
        """Run a ``task_arithmetic`` merge with ``model_c`` as base model."""
        config = self.two_model_config(
            model_a, model_b, merge_method="task_arithmetic", base_model=model_c
        )
        run_and_check_merge(config)

    def test_ties_merge(self, model_a, model_b, model_c):
        """Run a ``ties`` merge with ``model_c`` as base and density 0.3."""
        config = self.two_model_config(
            model_a,
            model_b,
            merge_method="ties",
            base_model=model_c,
            params={"density": 0.3},
        )
        run_and_check_merge(config)

    def test_dare_ties_merge(self, model_a, model_b, model_c):
        """Run a ``dare_ties`` merge with ``model_c`` as base and density 0.66."""
        config = self.two_model_config(
            model_a,
            model_b,
            merge_method="dare_ties",
            base_model=model_c,
            params={"density": 0.66},
        )
        run_and_check_merge(config)

    def two_model_config(
        self,
        model_a,
        model_b,
        merge_method: str,
        base_model: Optional[str] = None,
        params: Optional[Dict[str, ParameterSetting]] = None,
    ):
        """Build a two-input ``MergeConfiguration`` with fixed weights.

        Args:
            model_a: First input model; given ``weight`` 0.6.
            model_b: Second input model; given ``weight`` 0.4.
            merge_method: Value for the configuration's ``merge_method``.
            base_model: Optional value for the configuration's ``base_model``.
            params: Optional value for the configuration's top-level
                ``parameters``.

        Returns:
            The constructed ``MergeConfiguration`` (``dtype`` is "bfloat16").
        """
        return MergeConfiguration(
            merge_method=merge_method,
            base_model=base_model,
            models=[
                InputModelDefinition(
                    model=model_a,
                    parameters={"weight": 0.6},
                ),
                InputModelDefinition(
                    model=model_b,
                    parameters={"weight": 0.4},
                ),
            ],
            dtype="bfloat16",
            parameters=params,
        )
121

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.