google-research
157 lines · 5.4 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for inference."""
17
18import functools19import os20from typing import Tuple21
22from absl.testing import absltest23import jax.numpy as jnp24import numpy as np25
26import resources27from scaling_transformer_inference_efficiency import checkpoint28from scaling_transformer_inference_efficiency import incremental29from scaling_transformer_inference_efficiency import inference30from scaling_transformer_inference_efficiency import partitioning31from scaling_transformer_inference_efficiency import sampling32from scaling_transformer_inference_efficiency import weights33from scaling_transformer_inference_efficiency.layers import layers_pjit34
35# pylint: disable = line-too-long
# The PaLM correctness test relies on internal checkpoints, which return None in external code.
37
38
def golden_generation_unprompted():
  """Returns the golden 12-token unprompted generations for two samples.

  Found by running this test :inference_test at a CL that has been manually
  tested against PaLM-8B and was seen to produce meaningful English prose.
  """
  token_rows = [
      [0, 5, 94, 19, 3, 9, 182, 514, 97, 12, 129, 3],
      [0, 3, 9, 3, 60, 4312, 8, 3, 60, 4312, 5, 3],
  ]
  lengths = [12, 12]
  return incremental.Chunk(
      np.array(token_rows, np.int32), np.array(lengths, np.int32)
  )
47
class InferenceTest(absltest.TestCase):
  """Checks prefill scoring and generation against golden outputs."""

  def test_nonincremental_score(self):
    """Scoring a whole chunk in one prefill call matches the golden scores."""
    model, params = load_toy_model()
    chunk, expected_scores = get_golden()
    prefill_fn = model.instantiate_prefill_fn()
    result = model.prefill(params, prefill_fn, [], chunk).copy_to_host()
    # Zero out scores at padding positions before comparing.
    masked = np.where(chunk.token_mask, result.per_token_scores, 0)
    np.testing.assert_allclose(masked, expected_scores, rtol=0.02)

  def test_incremental_score(self):
    """Scoring a chunk in two pieces agrees with the golden scores."""
    model, params = load_toy_model()
    chunk, expected_scores = get_golden()
    prefill_fn = model.instantiate_prefill_fn()
    for boundary in (3, 6):
      first, second = chunk.split_at(boundary)
      out_first = model.prefill(params, prefill_fn, [], first)
      out_second = model.prefill(params, prefill_fn, [out_first], second)
      joined = jnp.concatenate(
          [
              out_first.copy_to_host().per_token_scores,
              out_second.copy_to_host().per_token_scores,
          ],
          axis=1,
      )
      masked = np.where(chunk.token_mask, joined, 0)
      np.testing.assert_allclose(masked, expected_scores, rtol=0.02)

  def test_unprompted_generation(self):
    """Generating 12 tokens from scratch reproduces the golden tokens."""
    model, params = load_toy_model()
    num_samples = 2
    sample_ids = np.arange(num_samples)
    steps = 12
    temperature = 0.7
    generate_fn = model.instantiate_generating_fn(steps)
    samples, _ = model.generate(
        params,
        generate_fn,
        [],
        sample_ids,
        sampling.SamplingHyperParams(temperature=temperature),
    )
    np.testing.assert_array_equal(
        np.array(samples.tokens), golden_generation_unprompted().tokens
    )

  def test_unprompted_generation_incremental(self):
    """Generation split across two calls matches the golden tokens."""
    model, params = load_toy_model()
    num_samples = 2
    sample_ids = np.arange(num_samples)
    total_steps = 12
    for head_steps in (1, 3):
      head_generate_fn = model.instantiate_generating_fn(head_steps)
      samples_head, state_head = model.generate(
          params,
          head_generate_fn,
          [],
          sample_ids,
          sampling.SamplingHyperParams(temperature=0.7),
      )
      tail_generate_fn = model.instantiate_generating_fn(
          total_steps - head_steps
      )
      samples_tail, _ = model.generate(
          params,
          tail_generate_fn,
          [state_head],
          sample_ids,
          sampling.SamplingHyperParams(temperature=0.7),
      )
      combined = np.concatenate(
          [np.array(samples_head.tokens), np.array(samples_tail.tokens)],
          axis=1,
      )
      np.testing.assert_array_equal(
          combined, golden_generation_unprompted().tokens
      )

  def test_prompted_generation_two_stages(self):
    """Prompting with the first 8 golden tokens regenerates the last 4."""
    model, params = load_toy_model()
    num_samples = 2
    sample_ids = np.arange(num_samples)

    golden = golden_generation_unprompted()
    # We'll prompt with the first 8 tokens (split into two chunks), and
    # regenerate the last 4 tokens.
    prompt_a, remainder = golden.split_at(4)
    prompt_b, tail = remainder.split_at(4)
    # Additionally, we'll pad the first prompt chunk, to test padding effects.
    prompt_a = prompt_a.pad_to_length(6)
    generate_fn = model.instantiate_generating_fn(tail.tokens.shape[1])
    prefill_fn = model.instantiate_prefill_fn()

    state_a = model.prefill(params, prefill_fn, [], prompt_a)
    state_b = model.prefill(params, prefill_fn, [state_a], prompt_b)
    samples, _ = model.generate(
        params,
        generate_fn,
        [state_a, state_b],
        sample_ids,
        sampling.SamplingHyperParams(temperature=0.7),
    )

    np.testing.assert_array_equal(np.array(samples.tokens), tail.tokens)
# Run the test suite via absl's test runner when invoked as a script.
if __name__ == '__main__':
  absltest.main()