gpt-neox

Форк
0
/
test_model_generation.py 
113 строк · 3.8 Кб
1
# Copyright (c) 2024, EleutherAI
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
"""
16
instantiate models, save checkpoints, load checkpoints, compare loaded parameters to saved parameters and compare forward pass outputs
17

These tests contain a relatively large number of functions. They are not split into separate tests because a lot of boilerplate (e.g. model instantiation) needs
to run in order to perform follow-up tests. Joining them in one test reduces runtime at the expense of decreased transparency of test results in case of failures.
20
"""
21

22

23
import os
24
import pytest
25
from tests.common import DistributedTest, model_setup, parametrize
26

27
# Grid of settings the generation test is parameterized over. Each key is a
# comma-joined group of neox arg names; each value is a list of candidate
# value-tuples for that group. `parametrize` (tests.common) samples
# combinations from this grid.
PARAMS_TO_TEST = {
    # [pipe_parallel_size, model_parallel_size, world_size] combinations;
    # world_size must match run_generate_test_class.world_size (2) except for
    # the single-process baseline case.
    "pipe_parallel_size,model_parallel_size,world_size": [
        [0, 1, 1],
        [0, 1, 2],
        [1, 2, 2],
        [0, 2, 2],
        [2, 1, 2],
    ],
    # Sampling settings: greedy-ish (temperature 0.5, no top-k/p), pure top-k,
    # and combined top-p + temperature.
    "top_p,temperature,top_k": [[0.0, 0.5, 0], [0.5, 0.0, 100], [0.5, 0.5, 0]],
    # Empty prompt exercises unconditional generation.
    "prompt": ["", "hello world"],
    # Two precision configs: bfloat16 with fp32 allreduce, and fp16 without.
    "fp16,fp32_allreduce": [
        [
            {
                "enabled": True,
                "type": "bfloat16",
                "loss_scale": 0,
                "loss_scale_window": 1000,
                "hysteresis": 2,
                "min_loss_scale": 1,
            },
            True,
        ],
        [
            {
                "enabled": True,
                "loss_scale": 0,
                "loss_scale_window": 1000,
                "hysteresis": 2,
                "min_loss_scale": 1,
            },
            False,
        ],
    ],
}
61

62
# Expand PARAMS_TO_TEST into concrete pytest cases. The number of generated
# cases is capped via the MAX_TESTCASES env var (default 50); seed=None keeps
# the sampling non-deterministic across runs.
_max_cases = int(os.getenv("MAX_TESTCASES", 50))
parameters, names = parametrize(PARAMS_TO_TEST, max_tests=_max_cases, seed=None)
65

66

67
@pytest.mark.skip
@pytest.mark.parametrize("param_dict", parameters, ids=names)
def test_train(param_dict):
    """Run one sampled parameter combination through the distributed generation test."""
    # The prompt is consumed separately from the rest of the neox args.
    prompt = param_dict.pop("prompt")
    tester = run_generate_test_class()
    tester.run_generate_test(param_dict, prompt)
72

73

74
class run_generate_test_class(DistributedTest):
    """Distributed test: instantiate a model and verify text generation output.

    Runs under two processes (``world_size = 2``). Generated samples are only
    produced on model-parallel rank 0, so assertions are guarded accordingly.
    """

    world_size = 2

    # BUG FIX: the original signature was `def run_generate_test(param_dict,
    # prompt)` with no `self`, but the method is invoked as a bound method
    # (`t1.run_generate_test(param_dict, prompt)` in test_train), which raised
    # `TypeError: takes 2 positional arguments but 3 were given`. The bug was
    # masked because the test is marked with @pytest.mark.skip.
    def run_generate_test(self, param_dict, prompt):
        """Build a model from ``param_dict`` and generate samples from ``prompt``.

        Args:
            param_dict: dict of neox args under test; mutated in place by
                merging in the fixed test parameters below.
            prompt: text prompt, replicated ``num_samples`` times.
        """
        from megatron.text_generation_utils import generate_samples_from_prompt
        from megatron.utils import is_mp_rank_0

        # Settings held constant across all parameterized runs.
        fixed_params = {
            "num_samples": 3,
            "maximum_tokens": 50,
            "make_vocab_size_divisible_by": 2,
            "sample_output_file": "test_sample_output.txt",
            "checkpoint_activations": False,
            "partition_activations": False,
            "no_load_optim": True,
        }

        param_dict.update(fixed_params)
        # TODO: we don't need to reinstantiate the model every time if we're
        # only changing sampling settings - should be a workaround for this
        model, _, _, args_loaded = model_setup(None, param_dict, clear_data=True)
        model.eval()

        # Same prompt repeated so each sample can be checked independently.
        prompts = [prompt for _ in range(args_loaded.num_samples)]
        output = generate_samples_from_prompt(
            neox_args=args_loaded,
            model=model,
            text=prompts,
            maximum_tokens=args_loaded.maximum_tokens,
            recompute=False,
            temperature=args_loaded.temperature,
            top_k=args_loaded.top_k,
            top_p=args_loaded.top_p,
        )

        # outputs only get generated on mp rank 0
        if is_mp_rank_0():
            assert len(output) == len(prompts)
            # Use a distinct loop name to avoid shadowing the `prompt` parameter.
            for expected_prompt, out in zip(prompts, output):
                # Each sample must echo its prompt as context and produce
                # non-empty generated text.
                assert expected_prompt == out["context"]
                assert len(out["text"]) > 0
114

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.