# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Random or greedy sampling from the output logits of a model."""
from typing import Any, Optional

from flax import struct
import jax
from jax import lax
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
import typing_extensions

from scaling_transformer_inference_efficiency import partitioning

# from t5x import binary_search


@struct.dataclass
class SamplingHyperParams:
  temperature: Any
  top_k: Optional[Any] = None
  top_p: Optional[Any] = None

  @classmethod
  def physical_axes(cls):
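    # The hyperparameters are scalars, so every field is fully replicated
    # (an empty PartitionSpec per field).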
    return SamplingHyperParams(temperature=P(), top_k=P(), top_p=P())


def sample(
    logits,
    step_rngs,
    hyper_params,
    mesh):
  """Samples from the output logits of a model.

  Args:
    logits: The output logits to sample from. float32[batch, vocab_size].
    step_rngs: For each batch element, the RNG state for sampling.
      jax.random.PRNGKey[batch]
    hyper_params: Sampling hyperparameters; see SamplingHyperParams.
    mesh: Unused; accepted for signature compatibility with the manual
      sampling variants.

  Returns:
    The selected samples, as token IDs. int32[batch].
  """
  del mesh  # Present only for compatibility with the manual variants.
  # Ensure the logits are unsharded along the vocab dimension.
  # pylint: disable = protected-access
  logits = partitioning._with_sharding_constraint(
      logits, P('logit_batch', None)
  )
  # logits = binary_search.topp_mask(logits, hyper_params.top_p, -1e10)

  if hyper_params.top_k is not None:
    logits, top_k_indices = lax.approx_max_k(
        logits, hyper_params.top_k, recall_target=1.0
    )
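    # With recall_target=1.0 the approximate op is configured for full recall,
    # i.e. effectively an exact top-k: `logits` now holds the k largest logits
    # and `top_k_indices` their vocabulary ids, which are used below to map
    # the sampled positions back to token ids.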

  def sample_nonzero():
    # jax.random.categorical expects just one rng. We use vmap to extend it to
    # support a batch of rngs.

    return jnp.int32(
        jax.vmap(jax.random.categorical)(
            step_rngs, logits / hyper_params.temperature
        )
    )

  def sample_zero():
    return jnp.int32(jnp.argmax(logits, -1))

  # To avoid numerical instability when dividing by very small temperatures,
  # we sample deterministically (greedily) when the temperature is
  # sufficiently close to zero.
  sampled_logits = lax.cond(
      hyper_params.temperature > 1e-4, sample_nonzero, sample_zero
  )

  if hyper_params.top_k is not None:
    sampled_logits = jax.vmap(lambda indices, sampled: indices[sampled])(
        top_k_indices, sampled_logits  # pylint: disable=undefined-variable
    )

  return partitioning._with_sharding_constraint(
      sampled_logits, P('logit_batch')
  )


def sample_manual(
    logits,
    step_rngs,
    hyper_params,
    mesh,
    batch_unsharded = False,
):
  """Samples from the output logits under manual partitioning (shard_map)."""

  def lowered_fn(logits, step_rngs):
    y_axis = lax.psum(1, 'y')
    z_axis = lax.psum(1, 'z')
    yz_index = lax.axis_index('y') * z_axis + lax.axis_index('z')
    batch, _ = logits.shape

    with jax.named_scope('sample'):
      # logits: float32[batch, vocab.YZ] || float32[batch.X, vocab.YZ]
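      # The vocab axis arrives sharded over the ('y', 'z') mesh axes (and the
      # batch over 'x' unless batch_unsharded). Depending on how the local
      # batch size compares to the 'z' and 'y' * 'z' axis sizes, either gather
      # the full vocab, or trade vocab sharding for extra batch sharding via
      # all_to_all and slice step_rngs to match the extra batch sharding.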

      if batch < z_axis:
        # float32[batch, vocab.YZ] -> float32[batch, vocab]
        # || float32[batch.X, vocab.YZ] -> float32[batch.X, vocab]
        logits = lax.all_gather(logits, ('y', 'z'), axis=1, tiled=True)
        all_gather_tokens = None if batch_unsharded else 'x'
      elif batch >= z_axis and batch < y_axis * z_axis:
        # float32[batch, vocab.YZ] -> float32[batch.Z, vocab]
        # || float32[batch.X, vocab.YZ] -> float32[batch.XZ, vocab]
        logits = lax.all_to_all(
            logits, 'z', split_axis=0, concat_axis=1, tiled=True
        )
        logits = lax.all_gather(logits, 'y', axis=1, tiled=True)
        split_size = batch // z_axis
        step_rngs = lax.dynamic_slice_in_dim(
            step_rngs, lax.axis_index('z') * split_size, (split_size), axis=0
        )
        all_gather_tokens = 'z' if batch_unsharded else ('x', 'z')
      elif batch >= y_axis * z_axis:
        # float32[batch, vocab.YZ] -> float32[batch.YZ, vocab]
        # || float32[batch.X, vocab.YZ] -> float32[batch.XYZ, vocab]
        logits = lax.all_to_all(
            logits, ('y', 'z'), split_axis=0, concat_axis=1, tiled=True
        )
        split_size = batch // y_axis // z_axis
        step_rngs = lax.dynamic_slice_in_dim(
            step_rngs, yz_index * split_size, (split_size), axis=0
        )
        all_gather_tokens = ('y', 'z') if batch_unsharded else ('x', 'y', 'z')
      else:
        raise NotImplementedError

      assert logits.shape[0] == step_rngs.shape[0]
      # TODO(sholto): Confirm this is the best way of doing it
      # logits = binary_search.topp_mask(logits, 0.9, -1e10)
      # TODO(sholto): maybe put t5x binary search back in
      sample_result = jnp.int32(
          jax.vmap(jax.random.categorical)(
              step_rngs, logits / hyper_params.temperature
          )
      )
      if all_gather_tokens is not None:
        # sample: int32[batch]
        sample_result = lax.all_gather(
            sample_result, all_gather_tokens, axis=0, tiled=True
        )
      return sample_result

  logit_specs = partitioning.logical_to_physical(P('logit_batch', 'vocab'))
  rng_specs = partitioning.logical_to_physical(P('logit_batch', None))
  # If the rngs cannot be sharded as requested, do not shard them:
  # rng_specs = partitioning.safe_sharding(step_rngs, P(('x', 'y', 'z')), mesh)
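  # The sampled token ids end up replicated across the mesh inside lowered_fn
  # (either all_gathered or computed redundantly), so the output is declared
  # unsharded via out_specs=P(None).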
  sample_result = shard_map(
      lowered_fn,
      mesh=mesh,
      in_specs=(logit_specs, rng_specs),
      out_specs=P(None),
      check_rep=False,
  )(logits, step_rngs)

  return sample_result


def sample_manual_batch_unsharded(
    logits,
    step_rngs,
    hyper_params,
    mesh):
  """Samples from the logits under manual partitioning, with batch unsharded.

  Args:
    logits: [batch, vocab.YZX]
    step_rngs: [batch]
    hyper_params: Sampling hyperparameters; see SamplingHyperParams.
    mesh: For manual collectives.

  Returns:
    sample: int32[batch]
  """
  def lowered_fn(logits, step_rngs):
    with jax.named_scope('sample'):
      # Multi-axis all_gather is not implemented for xmap under jit (see
      # lax.parallel), so gather the vocab over each mesh axis separately.
      logits = lax.all_gather(logits, 'x', axis=1, tiled=True)
      logits = lax.all_gather(logits, 'z', axis=1, tiled=True)
      logits = lax.all_gather(logits, 'y', axis=1, tiled=True)
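      # logits is now fully replicated: float32[batch, vocab]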
      assert logits.shape[0] == step_rngs.shape[0]
      sample_result = jnp.int32(
          jax.vmap(jax.random.categorical)(
              step_rngs, logits / hyper_params.temperature
          )
      )
      return sample_result

  logit_specs = partitioning.logical_to_physical(P('logit_batch', 'vocab'))
  sample_result = shard_map(
      lowered_fn,
      mesh=mesh,
      in_specs=(logit_specs, P(None)),
      out_specs=P(None),
      check_rep=False,
  )(logits, step_rngs)
  return sample_result


class SampleFn(typing_extensions.Protocol):
  """A function that samples token ids from a model's output logits."""

  def __call__(
      self,
      logits,
      step_rngs,
      hyper_params,
      mesh,
  ):
    Ellipsis
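

# Illustrative usage sketch (not part of the library API; the hyperparameter
# values are assumptions, and `logits` / `mesh` come from the caller):
#
#   hyper_params = SamplingHyperParams(temperature=0.7, top_k=40)
#   step_rngs = jax.random.split(jax.random.PRNGKey(0), logits.shape[0])
#   tokens = sample(logits, step_rngs, hyper_params, mesh)  # int32[batch]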