#!/usr/bin/env python3

# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import importlib
import os
import unittest

import torch
import torch.nn.init as init

from peft import LoraConfig, PeftModel, get_peft_model, get_peft_model_state_dict


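# The Megatron-specific helpers and tests below are defined only when the
# `megatron` package is importable.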
def is_megatron_available() -> bool:
    return importlib.util.find_spec("megatron") is not None


if is_megatron_available():
    from megatron.core import parallel_state, tensor_parallel
    from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
    from megatron.core.transformer.module import MegatronModule
    from megatron.core.transformer.transformer_config import TransformerConfig

    world_size = 1
    rank = 0

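    # Start a single-process NCCL group (rank 0, world size 1) so that Megatron's
    # model-parallel state can be initialized on a single GPU.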
    def initialize_distributed():
        print(f"Initializing torch.distributed with rank: {rank}, world_size: {world_size}")
        torch.cuda.set_device(0)
        init_method = "tcp://"
        master_ip = os.getenv("MASTER_ADDR", "localhost")
        master_port = os.getenv("MASTER_PORT", "6001")
        init_method += master_ip + ":" + master_port
        torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank, init_method=init_method)

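    # Tear down model-parallel state and synchronize between tests.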
    def destroy_model_parallel():
        parallel_state.destroy_model_parallel()
        torch.distributed.barrier()

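    # Re-create tensor/pipeline model-parallel groups, initializing the distributed
    # backend first if it is not already up.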
    def initialize_model_parallel(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        virtual_pipeline_model_parallel_size=None,
        pipeline_model_parallel_split_rank=None,
    ):
        parallel_state.destroy_model_parallel()
        if not torch.distributed.is_initialized():
            initialize_distributed()
        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank,
        )

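    # Minimal tensor-parallel module: a ColumnParallelLinear ("linear") feeding a
    # RowParallelLinear ("lm_head"), both 10x10 and bias-free.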
    class DummyModule(MegatronModule):
        def __init__(self, config: TransformerConfig):
            super().__init__(config)
            self.linear = tensor_parallel.ColumnParallelLinear(
                input_size=10,
                output_size=10,
                config=config,
                init_method=init.xavier_normal_,
                bias=False,
                gather_output=False,
            )
            self.lm_head = tensor_parallel.RowParallelLinear(
                input_size=10,
                output_size=10,
                config=config,
                init_method=init.xavier_normal_,
                bias=False,
                input_is_parallel=True,
                skip_bias_add=True,
            )

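        # Megatron's parallel linear layers return an (output, bias) tuple, so take [0]
        # to keep only the activations.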
        def forward(self, input):
            x = self.linear(input)[0]
            x = self.lm_head(x)[0]
            return x

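    # End-to-end checks for LoRA on Megatron tensor-parallel layers: module wiring,
    # forward equivalence at initialization, backward/optimizer step, and the PEFT state dict.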
    class TestMegatronLora(unittest.TestCase):
        def setUp(self):
            initialize_model_parallel(1, 1)
            model_parallel_cuda_manual_seed(123)
            transformer_config = {
                "num_layers": 2,
                "hidden_size": 12,
                "num_attention_heads": 4,
                "use_cpu_initialization": True,
            }
            config = TransformerConfig(**transformer_config)
            self.megatron_module = DummyModule(config=config).cuda()
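            # Keep an unwrapped deep copy as a reference model for comparing outputs.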
            self.dummy_module = copy.deepcopy(self.megatron_module).cuda()

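            # LoRA targets both tensor-parallel layers; megatron_config and megatron_core
            # let PEFT build the adapters with Megatron's parallel linear classes
            # (asserted in test_megatron_lora_module below).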
            lora_config = LoraConfig(
                lora_alpha=16,
                lora_dropout=0.1,
                r=64,
                bias="none",
                target_modules=["linear", "lm_head"],
                megatron_config=config,
                megatron_core="megatron.core",
            )
            self.megatron_module = get_peft_model(self.megatron_module, lora_config)

        def tearDown(self):
            destroy_model_parallel()

        def test_megatron_lora_module(self):
            megatron_module = self.megatron_module
            assert isinstance(megatron_module, PeftModel)

            for name, module in megatron_module.named_modules():
                if name.endswith("linear"):
                    assert hasattr(module, "lora_A")
                    assert hasattr(module, "lora_B")
                if name.endswith("linear.lora_A.default"):
                    assert isinstance(module, torch.nn.Linear)
                if name.endswith("linear.lora_B.default"):
                    assert isinstance(module, tensor_parallel.ColumnParallelLinear)

                if name.endswith("lm_head.lora_A.default"):
                    assert isinstance(module, tensor_parallel.RowParallelLinear)
                if name.endswith("lm_head.lora_B.default"):
                    assert isinstance(module, torch.nn.Linear)

        def test_forward(self):
            x = torch.ones((2, 4, 10)).cuda()
            megatron_module_result = self.megatron_module(x)
            dummy_module_result = self.dummy_module(x)

            # Because lora_B is initialized with 0, the forward results of the two models should be equal before backward.
            assert megatron_module_result.equal(dummy_module_result)

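        # Smoke test: the loss should backpropagate through the LoRA layers and the
        # optimizer step should run without error.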
        def test_backward(self):
            optimizer = torch.optim.AdamW(self.megatron_module.parameters())
            loss_fn = torch.nn.CrossEntropyLoss()

            x = torch.randn(2, 4, 10, requires_grad=True).cuda()
            label = torch.randint(10, (2 * 4,)).cuda()

            output = self.megatron_module(x)
            output = output.reshape(2 * 4, 10)
            loss = loss_fn(output, label)

            loss.backward()
            optimizer.step()

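        # Only LoRA parameters should appear in the PEFT state dict.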
        def test_get_peft_model_state_dict(self):
            peft_state_dict = get_peft_model_state_dict(self.megatron_module)

            for key in peft_state_dict.keys():
                assert "lora" in key
