google-research
234 lines · 7.3 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Random or greedy sampling from the output logits of a model."""
17from typing import Any, Optional
18
19from flax import struct
20import jax
21from jax import lax
22from jax.experimental.shard_map import shard_map
23import jax.numpy as jnp
24from jax.sharding import Mesh
25from jax.sharding import PartitionSpec as P
26import typing_extensions
27
28from scaling_transformer_inference_efficiency import partitioning
29
30# from t5x import binary_search
31
32
@struct.dataclass
class SamplingHyperParams:
  """Hyperparameters controlling token sampling.

  Attributes:
    temperature: Softmax temperature; values near zero select greedy decoding.
    top_k: If set, restricts sampling to the k highest-scoring logits.
    top_p: If set, nucleus-sampling threshold (currently unused downstream).
  """

  temperature: Any
  top_k: Optional[Any] = None
  top_p: Optional[Any] = None

  @classmethod
  def physical_axes(cls):
    """Returns fully-replicated partition specs for every field."""
    return cls(temperature=P(), top_k=P(), top_p=P())
42
43
def sample(
    logits,
    step_rngs,
    hyper_params,
    mesh):
  """Draws one token id per batch element from `logits`.

  Args:
    logits: The output logits to sample from. float32[batch, vocab_size].
    step_rngs: For each batch element, the RNG state for sampling.
      jax.random.PRNGKey[batch]
    hyper_params: Sampling settings; `temperature` and optional `top_k` are
      used here (`top_p` is currently ignored — see commented-out
      binary_search call).
    mesh: Unused; kept so the signature matches the manual-mode variants.

  Returns:
    The selected samples, as token IDs. int32[batch].
  """
  del mesh  # used for manual mode compat
  # Ensure it is unsharded along vocab dimension
  # pylint: disable = protected-access
  logits = partitioning._with_sharding_constraint(
      logits, P('logit_batch', None)
  )
  # logits = binary_search.topp_mask(logits, hyper_params.top_p, -1e10)

  use_top_k = hyper_params.top_k is not None
  if use_top_k:
    # Shrink the candidate set to the k best entries, remembering their
    # original vocab positions so sampled indices can be mapped back.
    logits, top_k_indices = lax.approx_max_k(
        logits, hyper_params.top_k, recall_target=1.0
    )

  def _random_choice():
    # jax.random.categorical expects just one rng; vmap extends it to a
    # batch of rngs.
    draws = jax.vmap(jax.random.categorical)(
        step_rngs, logits / hyper_params.temperature
    )
    return jnp.int32(draws)

  def _greedy_choice():
    return jnp.int32(jnp.argmax(logits, -1))

  # To avoid numerical instability when dividing by very small temperatures,
  # we sample deterministically (greedily) when the temperature is
  # sufficiently close to zero.
  tokens = lax.cond(
      hyper_params.temperature > 1e-4, _random_choice, _greedy_choice
  )

  if use_top_k:
    # Map positions within the top-k slice back to vocab token ids.
    tokens = jax.vmap(lambda idx, tok: idx[tok])(
        top_k_indices, tokens  # pylint: disable=undefined-variable
    )

  return partitioning._with_sharding_constraint(tokens, P('logit_batch'))
102
103
def sample_manual(
    logits,
    step_rngs,
    hyper_params,
    mesh,
    batch_unsharded = False,
):
  """Samples from the output logits when within xmap.

  Args:
    logits: Scores to sample from; per the comments below, per-device shards
      are float32[batch, vocab.YZ] (or float32[batch.X, vocab.YZ] when the
      batch is sharded over 'x').
    step_rngs: Per-batch-element RNG states. jax.random.PRNGKey[batch].
    hyper_params: Sampling settings; only `temperature` is used here.
    mesh: Device mesh used by shard_map for the manual collectives.
    batch_unsharded: If True, the batch dimension is not sharded over 'x'.

  Returns:
    The selected samples, as token IDs. int32[batch].
  """

  def lowered_fn(logits, step_rngs):
    # Runs per device under shard_map; mesh axis names 'x'/'y'/'z' are in
    # scope for collectives.
    y_axis = lax.psum(1, 'y')  # size of the 'y' mesh axis
    z_axis = lax.psum(1, 'z')  # size of the 'z' mesh axis
    # Linear index of this device within the (y, z) sub-grid.
    yz_index = lax.axis_index('y') * z_axis + lax.axis_index('z')
    batch, _ = logits.shape

    with jax.named_scope('sample'):
      # logits: float32[batch, vocab.YZ] || float32[batch.X, vocab.YZ]

      # Sampling needs the full vocab dimension on-device, so the vocab
      # shards are gathered; when the batch is large enough, the batch is
      # simultaneously re-sharded over 'z' (and 'y') via all_to_all so the
      # sampling work stays spread across devices.
      if batch < z_axis:
        # float32[batch, vocab.YZ] -> float32[batch, vocab]
        # || float32[batch.X, vocab.YZ] -> float32[batch.X, vocab]
        logits = lax.all_gather(logits, ('y', 'z'), axis=1, tiled=True)
        all_gather_tokens = None if batch_unsharded else 'x'
      elif batch >= z_axis and batch < y_axis * z_axis:
        # float32[batch, vocab.YZ] -> float32[batch.Z, vocab]
        # || float32[batch.X, vocab.YZ] -> float32[batch.XZ, vocab]
        logits = lax.all_to_all(
            logits, 'z', split_axis=0, concat_axis=1, tiled=True
        )
        logits = lax.all_gather(logits, 'y', axis=1, tiled=True)
        # Keep only the rng rows matching this device's slice of the batch.
        split_size = batch // z_axis
        step_rngs = lax.dynamic_slice_in_dim(
            step_rngs, lax.axis_index('z') * split_size, (split_size), axis=0
        )
        all_gather_tokens = 'z' if batch_unsharded else ('x', 'z')
      elif batch >= y_axis * z_axis:
        # float32[batch, vocab.YZ] -> float32[batch.YZ, vocab]
        # || float32[batch.X, vocab.YZ] -> float32[batch.XYZ, vocab]
        logits = lax.all_to_all(
            logits, ('y', 'z'), split_axis=0, concat_axis=1, tiled=True
        )
        split_size = batch // y_axis // z_axis
        step_rngs = lax.dynamic_slice_in_dim(
            step_rngs, yz_index * split_size, (split_size), axis=0
        )
        all_gather_tokens = ('y', 'z') if batch_unsharded else ('x', 'y', 'z')
      else:
        raise NotImplementedError

      assert logits.shape[0] == step_rngs.shape[0]
      # TODO(sholto): Confirm this is the best way of doing it
      # logits = binary_search.topp_mask(logits, 0.9, -1e10)
      # TODO(sholto): maybe put t5x binary search back in
      sample_result = jnp.int32(
          jax.vmap(jax.random.categorical)(
              step_rngs, logits / hyper_params.temperature
          )
      )
      if all_gather_tokens is not None:
        # sample: int32[batch]
        # Reassemble the full batch of sampled tokens on every device.
        sample_result = lax.all_gather(
            sample_result, all_gather_tokens, axis=0, tiled=True
        )
      return sample_result

  logit_specs = partitioning.logical_to_physical(P('logit_batch', 'vocab'))
  rng_specs = partitioning.logical_to_physical(P('logit_batch', None))
  # if it cannot be sharded as such, then do not
  # rng_specs = partitioning.safe_sharding(step_rngs, P(('x', 'y', 'z')), mesh)
  sample_result = shard_map(
      lowered_fn,
      mesh=mesh,
      in_specs=(logit_specs, rng_specs),
      out_specs=P(None),
      check_rep=False,
  )(logits, step_rngs)

  return sample_result
182
183
def sample_manual_batch_unsharded(
    logits,
    step_rngs,
    hyper_params,
    mesh):
  """Samples within xmap when the batch dimension cannot be sharded.

  Args:
    logits: [batch, vocab.YZX]
    step_rngs: [batch]
    hyper_params: Sampling settings; only `temperature` is used here.
    mesh: for manual collectives

  Returns:
    sample: int32[batch]
  """

  def per_device_sample(shard_logits, shard_rngs):
    with jax.named_scope('sample'):
      # multi-part all gather not implemented for xmap in jit see lax.parallel
      for axis_name in ('x', 'z', 'y'):
        shard_logits = lax.all_gather(
            shard_logits, axis_name, axis=1, tiled=True
        )
      assert shard_logits.shape[0] == shard_rngs.shape[0]
      scaled = shard_logits / hyper_params.temperature
      return jnp.int32(jax.vmap(jax.random.categorical)(shard_rngs, scaled))

  logit_specs = partitioning.logical_to_physical(P('logit_batch', 'vocab'))
  sharded_sampler = shard_map(
      per_device_sample,
      mesh=mesh,
      in_specs=(logit_specs, P(None)),
      out_specs=P(None),
      check_rep=False,
  )
  return sharded_sampler(logits, step_rngs)
222
223
class SampleFn(typing_extensions.Protocol):
  """Structural type of a sampling function (e.g. `sample`, `sample_manual`)."""

  def __call__(
      self,
      logits,
      step_rngs,
      hyper_params,
      mesh,
  ):
    ...