# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
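"""Tests for LoRA applied to Megatron-LM tensor-parallel layers.

Requires the `megatron.core` package and a CUDA device.
"""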
import copy
import importlib
import os
import unittest

import torch
import torch.nn.init as init

from peft import LoraConfig, PeftModel, get_peft_model, get_peft_model_state_dict

def is_megatron_available() -> bool:
    return importlib.util.find_spec("megatron") is not None

if is_megatron_available():
    from megatron.core import parallel_state, tensor_parallel
    from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
    from megatron.core.transformer.module import MegatronModule
    from megatron.core.transformer.transformer_config import TransformerConfig

    # Single-process defaults (assumed) used by initialize_distributed() below.
    world_size = 1
    rank = 0

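    # The helpers below bring up and tear down a minimal single-process Megatron
    # model-parallel state (tensor parallel = 1, pipeline parallel = 1) for the tests.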
    def initialize_distributed():
        print(f"Initializing torch.distributed with rank: {rank}, world_size: {world_size}")
        torch.cuda.set_device(0)
        init_method = "tcp://"
        master_ip = os.getenv("MASTER_ADDR", "localhost")
        master_port = os.getenv("MASTER_PORT", "6001")
        init_method += master_ip + ":" + master_port
        torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank, init_method=init_method)

    def destroy_model_parallel():
        parallel_state.destroy_model_parallel()
        torch.distributed.barrier()

    def initialize_model_parallel(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        virtual_pipeline_model_parallel_size=None,
        pipeline_model_parallel_split_rank=None,
    ):
        parallel_state.destroy_model_parallel()
        if not torch.distributed.is_initialized():
            initialize_distributed()
        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank,
        )

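    # DummyModule is a tiny tensor-parallel model: a ColumnParallelLinear feeding a
    # RowParallelLinear "lm_head". Both layers are LoRA targets in the tests below.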
    class DummyModule(MegatronModule):
        def __init__(self, config: TransformerConfig):
            super().__init__(config)
            # Layer sizes follow the (2, 4, 10) inputs used in the tests; the
            # bias/gather/skip flags are assumed single-GPU-friendly settings.
            self.linear = tensor_parallel.ColumnParallelLinear(
                input_size=10,
                output_size=10,
                config=config,
                init_method=init.xavier_normal_,
                bias=False,
                gather_output=False,
            )
            self.lm_head = tensor_parallel.RowParallelLinear(
                input_size=10,
                output_size=10,
                config=config,
                init_method=init.xavier_normal_,
                bias=False,
                input_is_parallel=True,
                skip_bias_add=True,
            )

        def forward(self, input):
            x = self.linear(input)[0]
            x = self.lm_head(x)[0]
            return x

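    # The test case builds DummyModule, keeps an unmodified deep copy, wraps the
    # original with LoRA via get_peft_model, and then checks the module layout,
    # forward parity, backward pass, and the PEFT state dict.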
    class TestMegatronLora(unittest.TestCase):
        def setUp(self):
            initialize_model_parallel(1, 1)
            model_parallel_cuda_manual_seed(123)
            transformer_config = {
                # num_layers and hidden_size are required by TransformerConfig;
                # small values assumed here, since DummyModule only uses the config for init.
                "num_layers": 2,
                "hidden_size": 12,
                "num_attention_heads": 4,
                "use_cpu_initialization": True,
            }
            config = TransformerConfig(**transformer_config)
            self.megatron_module = DummyModule(config=config).cuda()
            self.dummy_module = copy.deepcopy(self.megatron_module).cuda()

            lora_config = LoraConfig(
                target_modules=["linear", "lm_head"],
                megatron_config=config,
                megatron_core="megatron.core",
            )
            self.megatron_module = get_peft_model(self.megatron_module, lora_config)

        def tearDown(self):
            destroy_model_parallel()

        def test_megatron_lora_module(self):
            megatron_module = self.megatron_module
            assert isinstance(megatron_module, PeftModel)

            # For the column-parallel base layer, lora_A is a plain Linear and lora_B is
            # column-parallel; for the row-parallel base layer the roles are swapped.
            for name, module in megatron_module.named_modules():
                if name.endswith("linear"):
                    assert hasattr(module, "lora_A")
                    assert hasattr(module, "lora_B")
                if name.endswith("linear.lora_A.default"):
                    assert isinstance(module, torch.nn.Linear)
                if name.endswith("linear.lora_B.default"):
                    assert isinstance(module, tensor_parallel.ColumnParallelLinear)

                if name.endswith("lm_head.lora_A.default"):
                    assert isinstance(module, tensor_parallel.RowParallelLinear)
                if name.endswith("lm_head.lora_B.default"):
                    assert isinstance(module, torch.nn.Linear)

        def test_forward(self):
            x = torch.ones((2, 4, 10)).cuda()
            megatron_module_result = self.megatron_module(x)
            dummy_module_result = self.dummy_module(x)

            # Because lora_B is initialized with 0, the forward results of the two models
            # should be equal before any backward pass.
            assert megatron_module_result.equal(dummy_module_result)

        def test_backward(self):
            optimizer = torch.optim.AdamW(self.megatron_module.parameters())
            loss_fn = torch.nn.CrossEntropyLoss()

            x = torch.randn(2, 4, 10, requires_grad=True).cuda()
            label = torch.randint(10, (2 * 4,)).cuda()

            output = self.megatron_module(x)
            output = output.reshape(2 * 4, 10)
            loss = loss_fn(output, label)

            # Backpropagate and take one optimizer step over the LoRA parameters.
            loss.backward()
            optimizer.step()

        def test_get_peft_model_state_dict(self):
            peft_state_dict = get_peft_model_state_dict(self.megatron_module)

            # The PEFT state dict should only contain LoRA parameters.
            for key in peft_state_dict.keys():
                assert "lora" in key