google-research

Форк
0
157 строк · 5.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for inference."""
17

18
import functools
19
import os
20
from typing import Tuple
21

22
from absl.testing import absltest
23
import jax.numpy as jnp
24
import numpy as np
25

26
 import resources
27
from scaling_transformer_inference_efficiency import checkpoint
28
from scaling_transformer_inference_efficiency import incremental
29
from scaling_transformer_inference_efficiency import inference
30
from scaling_transformer_inference_efficiency import partitioning
31
from scaling_transformer_inference_efficiency import sampling
32
from scaling_transformer_inference_efficiency import weights
33
from scaling_transformer_inference_efficiency.layers import layers_pjit
34

35
# pylint: disable = line-too-long
36
# PaLM correctness test relies on internal checkpoints, return None in external code
37

38

39
def golden_generation_unprompted():
  """Returns the golden token chunk for unprompted generation.

  Found by running this test (:inference_test) at a CL that had been manually
  tested against PaLM-8B and was seen to produce meaningful English prose. The
  generation tests in this file compare samples drawn at temperature 0.7
  against these tokens.

  Returns:
    An `incremental.Chunk` built from a (2, 12) int32 token array and a
    length-2 int32 array.
  """
  tokens = np.array(
      [
          [0, 5, 94, 19, 3, 9, 182, 514, 97, 12, 129, 3],
          [0, 3, 9, 3, 60, 4312, 8, 3, 60, 4312, 5, 3],
      ],
      np.int32,
  )
  # Presumably per-row lengths (12 == the full row, i.e. no padding);
  # NOTE(review): confirm against the incremental.Chunk definition.
  lengths = np.array([12, 12], np.int32)
  return incremental.Chunk(tokens, lengths)
46

47

48
class InferenceTest(absltest.TestCase):
  """Tests prefill scoring and sampling-based generation on a toy model.

  NOTE(review): `load_toy_model` and `get_golden` are called below but are not
  defined anywhere in this file; presumably they were stripped from the
  external release. Confirm where they come from before running these tests.
  """

  # Sampling temperature used by every generation test; the tests expect
  # golden_generation_unprompted() to be reproduced at this temperature.
  _TEMPERATURE = 0.7

  def test_nonincremental_score(self):
    """Scores the golden chunk with a single prefill call."""
    model, test_weights = load_toy_model()
    golden_chunk, golden_token_scores = get_golden()
    prefill_fn = model.instantiate_prefill_fn()
    golden_chunk_result = model.prefill(
        test_weights, prefill_fn, [], golden_chunk
    ).copy_to_host()

    # Zero out scores wherever token_mask is false before comparing.
    scores = np.where(golden_chunk.token_mask,
                      golden_chunk_result.per_token_scores, 0)
    np.testing.assert_allclose(scores, golden_token_scores, rtol=0.02)

  def test_incremental_score(self):
    """Prefilling in two chunks must reproduce the single-call scores."""
    model, test_weights = load_toy_model()
    golden_chunk, golden_token_scores = get_golden()
    prefill_fn = model.instantiate_prefill_fn()

    for split in [3, 6]:
      # subTest reports which split failed and keeps checking the others.
      with self.subTest(split=split):
        a, b = golden_chunk.split_at(split)
        result_a = model.prefill(test_weights, prefill_fn, [], a)
        # The second chunk is prefilled on top of the first chunk's result.
        result_b = model.prefill(test_weights, prefill_fn, [result_a], b)
        # Both halves are already on the host, so join them with NumPy along
        # axis=1 (the axis split_at divided), as the generation test does.
        scores = np.concatenate(
            [
                result_a.copy_to_host().per_token_scores,
                result_b.copy_to_host().per_token_scores,
            ],
            axis=1,
        )
        scores = np.where(golden_chunk.token_mask, scores, 0)
        np.testing.assert_allclose(scores, golden_token_scores, rtol=0.02)

  def test_unprompted_generation(self):
    """Sampling from an empty context reproduces the golden tokens."""
    model, test_weights = load_toy_model()
    golden_tokens = golden_generation_unprompted().tokens
    # The golden chunk fixes both the batch size and the number of steps.
    num_samples, steps = golden_tokens.shape
    sample_ids = np.arange(num_samples)
    generate_fn = model.instantiate_generating_fn(steps)
    samples, _ = model.generate(
        test_weights,
        generate_fn,
        [],
        sample_ids,
        sampling.SamplingHyperParams(temperature=self._TEMPERATURE),
    )
    np.testing.assert_array_equal(np.array(samples.tokens), golden_tokens)

  def test_unprompted_generation_incremental(self):
    """Generating in two stages reproduces the single-stage golden tokens."""
    model, test_weights = load_toy_model()
    golden_tokens = golden_generation_unprompted().tokens
    num_samples, steps = golden_tokens.shape
    sample_ids = np.arange(num_samples)

    for split in [1, 3]:
      with self.subTest(split=split):
        # Stage one: generate `split` tokens from an empty context.
        generate_fn = model.instantiate_generating_fn(split)
        samples_a, result_a = model.generate(
            test_weights,
            generate_fn,
            [],
            sample_ids,
            sampling.SamplingHyperParams(temperature=self._TEMPERATURE),
        )
        # Stage two: continue from stage one's result for the remaining
        # `steps - split` tokens.
        generate_fn = model.instantiate_generating_fn(steps - split)
        samples_b, _ = model.generate(
            test_weights,
            generate_fn,
            [result_a],
            sample_ids,
            sampling.SamplingHyperParams(temperature=self._TEMPERATURE),
        )
        tokens = np.concatenate(
            [np.array(samples_a.tokens), np.array(samples_b.tokens)], axis=1
        )
        np.testing.assert_array_equal(tokens, golden_tokens)

  def test_prompted_generation_two_stages(self):
    """Prefilling a two-chunk prompt then sampling reproduces the tail."""
    model, test_weights = load_toy_model()
    golden_chunk = golden_generation_unprompted()
    num_samples = golden_chunk.tokens.shape[0]
    sample_ids = np.arange(num_samples)

    # Prompt with the first 8 tokens (split into two 4-token chunks), then
    # regenerate the last 4 tokens.
    chunk_a, chunk_bc = golden_chunk.split_at(4)
    chunk_b, chunk_c = chunk_bc.split_at(4)
    # Pad chunk_a to length 6; the assertion below checks that the padding
    # does not change the generated tokens.
    chunk_a = chunk_a.pad_to_length(6)
    generate_fn = model.instantiate_generating_fn(chunk_c.tokens.shape[1])
    prefill_fn = model.instantiate_prefill_fn()

    result_a = model.prefill(test_weights, prefill_fn, [], chunk_a)
    result_b = model.prefill(test_weights, prefill_fn, [result_a], chunk_b)
    samples, _ = model.generate(
        test_weights,
        generate_fn,
        [result_a, result_b],
        sample_ids,
        sampling.SamplingHyperParams(temperature=self._TEMPERATURE),
    )

    np.testing.assert_array_equal(np.array(samples.tokens), chunk_c.tokens)
154

155

156
if __name__ == '__main__':
  # Run the tests in this module through absl's test runner.
  absltest.main()
158

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.