paddlenlp

Форк
0
/
test_sequence_parallel.py 
98 строк · 3.9 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import unittest
16

17
import numpy as np
18
import paddle
19
import paddle.distributed.fleet as fleet
20
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import PipelineParallel
21

22
from paddlenlp.transformers import GPTConfig, GPTForCausalLM, GPTForCausalLMPipe
23

24

25
class TestGPT(unittest.TestCase):
26
    def test_sequence_model(self):
27
        model_name_or_path = "gpt2-medium-en"
28
        seq_len = 1024
29
        batch_size = 2
30
        input_ids = paddle.arange(100, 100 + batch_size * seq_len, dtype="int64").reshape([batch_size, seq_len])
31
        labels = paddle.arange(101, 101 + batch_size * seq_len, dtype="int64").reshape([batch_size, seq_len])
32

33
        world_size = paddle.distributed.get_world_size()
34
        pp_degree = 2
35
        tp_degree = world_size // pp_degree
36
        strategy = fleet.DistributedStrategy()
37
        strategy.hybrid_configs = {
38
            "dp_degree": 1,
39
            "mp_degree": tp_degree,
40
            "pp_degree": pp_degree,
41
            "sharding_degree": 1,
42
        }
43
        strategy.pipeline_configs = {"enable_partial_send_recv": False if pp_degree > 1 else True}
44
        fleet.init(is_collective=True, strategy=strategy)
45
        hcg = fleet.get_hybrid_communicate_group()
46
        mp_group = hcg.get_model_parallel_group()
47
        tensor_parallel_rank = mp_group.rank
48

49
        if pp_degree > 1:
50
            model_class = GPTForCausalLMPipe
51
        else:
52
            model_class = GPTForCausalLM
53

54
        config = GPTConfig.from_pretrained(model_name_or_path)
55
        config.seq_length = seq_len
56
        config.use_flash_attention = False
57
        config.fuse_attention_qkv = False
58
        config.recompute_granularity = "full"
59
        config.virtual_pp_degree = 1
60
        config.use_recompute = False
61

62
        config.tensor_parallel_degree = tp_degree
63
        config.tensor_parallel_rank = tensor_parallel_rank
64
        config.tensor_parallel_output = False
65
        # when tp_degree > 1, sequence_parallel can be set to True
66
        config.sequence_parallel = True
67
        config.fuse_sequence_parallel_allreduce = False
68

69
        model = model_class.from_pretrained(model_name_or_path, config=config, dtype="float32")
70
        model.eval()
71

72
        if pp_degree > 1:
73
            pp_model = PipelineParallel(layers=model, hcg=hcg, strategy=strategy)
74
            pp_model.accumulate_steps = batch_size  # for micro_batch_size * acc_steps == batch_size
75
            ret_mp_pp = pp_model.eval_batch(data=[input_ids, labels], compute_loss=True)
76
        else:
77
            ret_mp_pp = model(input_ids=input_ids, labels=labels)[0]
78

79
        # run model for single device
80
        config.tensor_parallel_degree = 1
81
        config.tensor_parallel_rank = -1
82
        config.sequence_parallel = False
83
        single_model = GPTForCausalLM.from_pretrained(model_name_or_path, config=config, dtype="float32")
84
        single_model.eval()
85
        ret_single = single_model(input_ids=input_ids, labels=labels)[0]
86

87
        # output all results
88
        print(f"ret mp{tp_degree} pp{pp_degree}", float(ret_mp_pp))
89
        print("ret single", float(ret_single))
90

91
        diff = (ret_single - ret_mp_pp) / ret_single
92
        print(f"diff: {float(diff)}")
93
        np.testing.assert_allclose(float(ret_single), ret_mp_pp, rtol=1.5e-7)
94

95

96
if __name__ == "__main__":
97
    TestGPT().test_sequence_model()
98
# python -m paddle.distributed.launch --gpus 0,1,2,3  tests/test_pipeline_parallel.py
99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.