# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Handy import wrappers."""

import dataclasses
from enum import Enum  # pylint: disable=g-importing-member
from functools import partial  # pylint: disable=g-importing-member
import logging
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union

from flax import struct
from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np
from seqio.vocabularies import Vocabulary
from t5x import losses
from t5x.models import DecoderOnlyModel

from scaling_transformer_inference_efficiency import checkpoint
from scaling_transformer_inference_efficiency import chunk
from scaling_transformer_inference_efficiency import incremental
from scaling_transformer_inference_efficiency import inference
from scaling_transformer_inference_efficiency import partitioning
from scaling_transformer_inference_efficiency import sampling
from scaling_transformer_inference_efficiency import weights
from scaling_transformer_inference_efficiency.layers import one_d_parallel_xmap
from scaling_transformer_inference_efficiency.layers import two_d_parallel_xmap


PyTree = Any


@struct.dataclass
class TestVocab:
  eos_id = 0
  bos_id = 0
  pad_id = 0

  def encode_tf(self, text):
    chars = np.array([ord(c) for c in text]).astype(np.int32)
    return chars

  def decode_tf(self, tokens):
    results = np.split(tokens, tokens.shape[0])
    return np.array([[chr(r) for r in list(line[0])] for line in results])
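  # Illustrative round trip with this toy vocab (values are what the NumPy
  # calls above produce; real checkpoints use checkpoint.load_vocab() instead):
  #   vocab = TestVocab()
  #   vocab.encode_tf('hi')                    # -> array([104, 105], dtype=int32)
  #   vocab.decode_tf(np.array([[104, 105]]))  # -> array([['h', 'i']], dtype='<U1')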


class Layout(Enum):
  TWO_D = 'two_d'
  ONE_D = 'one_d'
  WEIGHT_GATHERED = 'weight_gathered'


@dataclasses.dataclass
class ModelConfig:
  """An object to make gin file input elegant.

  Attributes:
    ckpt_path: Checkpoint path (typically a CNS path).
    size: Model size in billions of parameters: 8, 62 or 540 (0 selects the
      toy model).
    quantized: Whether to load a quantized checkpoint and weights.
    generate_steps: Number of decode steps to generate.
    kv_cache_sharding: The degree of kv cache sharding (0: None, 1: Z, 2: YZ,
      3: YZX).
    latency_collectives: Whether to use latency-optimised forms (double
      compute per step, half the steps for collective matmuls).
    batch_unsharded: Whether to leave the batch dim unsharded.
    shard_seqlen_vs_batch: Whether to shard seqlen instead of batch.
    stream: An object to facilitate streaming back to X (you define the
      callbacks).
    transpose_scan_axis: Transpose if layers were not saved as the leading
      axis.
    layout: Which partitioning layout to use (two_d, one_d, weight_gathered).
    bos_id: Optionally overwrite the bos_id passed to the model.
  """

  ckpt_path: str
  size: int
  quantized: bool
  generate_steps: int
  kv_cache_sharding: int
  latency_collectives: bool
  batch_unsharded: bool
  shard_seqlen_vs_batch: bool
  stream: Optional[incremental.StreamClient] = None
  transpose_scan_axis: bool = True
  layout: Layout = Layout.TWO_D
  bos_id: Optional[int] = None
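  # Example construction (illustrative values only; the checkpoint path is a
  # placeholder, not a real checkpoint):
  #   cfg = ModelConfig(
  #       ckpt_path='/path/to/palm_8b_checkpoint',
  #       size=8,
  #       quantized=False,
  #       generate_steps=32,
  #       kv_cache_sharding=1,
  #       latency_collectives=False,
  #       batch_unsharded=True,
  #       shard_seqlen_vs_batch=False,
  #   )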


def return_minimal_palm(
    cfg,
    params_already_loaded=False,
    remat = None,
    devices = None,
):  # pylint: disable = g-bare-generic, line-too-long
  """Utility function to return a model.

  Args:
    cfg: A model configuration
    params_already_loaded: whether params have been loaded yet
    remat: Whether to remat the layer, used for training.
      jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
      jax.checkpoint_policies.nothing_saveable
    devices: devices to make a mesh from

  Returns:
    model: A model wrapper
    params: The params
    prefill_fn: Function to pass as prefill (to ensure it is compilation cached)
    generate_fn: Function to pass as generation (to ensure it is compilation
      cached)
  """
  one_d = cfg.layout == Layout.ONE_D
  if cfg.shard_seqlen_vs_batch and cfg.batch_unsharded:
    raise NotImplementedError(
        "Either shard seqlen instead of batch or don't shard batch."
    )

  del remat  # for the moment, always remat
  # We have preset sizes
  if cfg.size == 0:
    hparams = checkpoint.HParams.TOY
  elif cfg.size == 8:
    hparams = checkpoint.HParams.PALM_8B
  elif cfg.size == 62:
    hparams = checkpoint.HParams.PALM_62B
  elif cfg.size == 540:
    hparams = checkpoint.HParams.PALM_540B
  else:
    raise NotImplementedError(f'No preset hparams for size {cfg.size}.')

  if cfg.quantized:
    ckpt = checkpoint.QuantizedCheckpoint
    params_spec = weights.QuantizedWeights
  else:
    ckpt = checkpoint.Checkpoint
    params_spec = weights.Weights

  if cfg.size == 0:
    loaded_ckpt = ckpt.init_zero(hparams)
  else:
    spec = checkpoint.CheckpointSpec(
        hparams=hparams,
        dir=cfg.ckpt_path,
        transpose_scan_axis=cfg.transpose_scan_axis,
    )
    loaded_ckpt = ckpt.load_spec(spec)

  if cfg.kv_cache_sharding == 0:
    attn_batch_sharding = partitioning.AttnAllToAll.NONE
  elif cfg.kv_cache_sharding == 1:
    attn_batch_sharding = partitioning.AttnAllToAll.AXIS_Z
  elif cfg.kv_cache_sharding == 2:
    attn_batch_sharding = partitioning.AttnAllToAll.AXES_YZ
  elif cfg.kv_cache_sharding == 3:
    attn_batch_sharding = partitioning.AttnAllToAll.AXES_YZX
  else:
    raise NotImplementedError

  if cfg.layout == Layout.TWO_D:
    rules = partitioning.make_rules_two_d(
        attn_batch_sharding, batch_unsharded=cfg.batch_unsharded
    )
    layer_fn = partial(
        two_d_parallel_xmap.transformer_layer_weight_stationary,
        attn_all_to_all=attn_batch_sharding,
        latency_collectives=cfg.latency_collectives,
        shard_seqlen_vs_batch=cfg.shard_seqlen_vs_batch,
        batch_unsharded=cfg.batch_unsharded,
    )
    # sample_fn = partial(sampling.sample_manual,
    # batch_unsharded=cfg.batch_unsharded)
    sample_fn = sampling.sample

  elif cfg.layout == Layout.ONE_D:
    rules = partitioning.make_rules_one_d()
    layer_fn = partial(
        one_d_parallel_xmap.weight_stationary_simple,
        latency_collectives=cfg.latency_collectives,
    )
    sample_fn = sampling.sample_manual_batch_unsharded
  elif cfg.layout == Layout.WEIGHT_GATHERED:
    rules = partitioning.make_rules_weight_gathered()
    sample_fn = sampling.sample
    raise NotImplementedError
  else:
    raise NotImplementedError

  if cfg.size == 0:
    the_vocab = TestVocab()
  else:
    the_vocab = checkpoint.load_vocab()

  mesh = partitioning.make_mesh(one_d=one_d, devices=devices)
  sharding_config = partitioning.ShardingConfig(
      mesh=mesh,
      attn_all_to_all=attn_batch_sharding,
      latency_collectives=cfg.latency_collectives,
      shard_seqlen_vs_batch=cfg.shard_seqlen_vs_batch,
      batch_unsharded=cfg.batch_unsharded,
  )

  embed_fn = partial(
      two_d_parallel_xmap.embed_manual,
      shard_seqlen_vs_batch=cfg.shard_seqlen_vs_batch,
      batch_unsharded=cfg.batch_unsharded,
      one_d=one_d,
  )

  unembed_fn = partial(
      two_d_parallel_xmap.unembed_manual,
      batch_unsharded=cfg.batch_unsharded,
      one_d=one_d,
  )

  forward_pass = partial(
      inference.manual_fwd_pass,
      hparams,
      sharding_config,
      embed_fn,
      layer_fn,
      unembed_fn,
  )

  infer_stack = partial(
      inference.infer_template,
      hparams,
      sharding_config,
      forward_pass,
  )

  model = incremental.InferenceModel(
      hparams,
      the_vocab.eos_id,
      infer_stack,
      sample_fn,
      mesh,
      rules,
      the_vocab,
      bos_id=cfg.bos_id,
  )

  generate_fn = model.instantiate_generating_fn(cfg.generate_steps)
  prefill_fn = model.instantiate_prefill_fn()

  if params_already_loaded:
    return model, None, prefill_fn, generate_fn
  else:
    # actually load the weights
    with model.mesh, model.rules:
      params = params_spec.from_checkpoint(hparams, model.mesh, loaded_ckpt)

    logging.info('Weights loaded.')

    # cs2 = cs.replace(hparams = cs.hparams.replace(heads=64, padded_heads=32))
    params = (
        model.rotate_weights(params, cfg.latency_collectives)
        if cfg.latency_collectives
        else params
    )
    logging.info('Weights formatted.')
  return model, params, prefill_fn, generate_fn
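
# Usage sketch (illustrative; `cfg` is a ModelConfig as above, `prompt` a
# chunk.Chunk of tokenised inputs and `batch_size` the number of sequences to
# sample). The call pattern mirrors the InferenceT5X methods below:
#
#   model, params, prefill_fn, generate_fn = return_minimal_palm(cfg)
#   prompt_cache = model.prefill(params, prefill_fn, [], prompt)
#   output, result = model.generate(
#       params,
#       generate_fn,
#       [prompt_cache],
#       np.arange(batch_size),
#       sampling.SamplingHyperParams(temperature=0.7),
#   )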


@jax.jit
def find_common_prefix(tokens):
  """Finds the length of the prefix shared by every row of `tokens`."""
  base_case = tokens[0, :]
  is_equal = jnp.int8(tokens == base_case)  # broadcasts across the batch
  equal_at = jnp.prod(is_equal, axis=0)  # get a single dimensional array
  cp = jnp.cumprod(equal_at, 0)
  first_non_equal = jnp.argmin(cp)  # will get the first 0
  return first_non_equal
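
# Worked example (illustrative): for
#   tokens = jnp.array([[7, 8, 9, 1],
#                       [7, 8, 3, 1]])
# the rows agree at positions 0 and 1 and first differ at position 2, so
# find_common_prefix(tokens) returns 2. If every row is identical the
# cumulative product never reaches zero and argmin returns index 0.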


@jax.jit
def ce_loss(
    score_result, batch
):
  """Per-token scores: negative cross entropy, masked by the loss weights."""
  token_scores = (
      -losses.cross_entropy_with_logits(
          score_result.logits,
          common_utils.onehot(
              batch['decoder_target_tokens'],
              score_result.logits.shape[-1],
              on_value=1,
              off_value=0,
          ),
          z_loss=0.0,
      )[0]
      * batch['decoder_loss_weights']
  )
  return token_scores
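
# Shape sketch (assumed shapes, following the batch format documented in
# InferenceT5X.predict_batch below):
#   score_result.logits:            [batch, time, vocab]
#   batch['decoder_target_tokens']: [batch, time]
#   batch['decoder_loss_weights']:  [batch, time]
# The returned token_scores are [batch, time] per-token log-likelihoods,
# zeroed wherever the loss weight is zero.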


# pylint: disable = g-bare-generic
# pylint: disable = invalid-name
@dataclasses.dataclass
class InferenceT5X(DecoderOnlyModel):
  """Creates an API that fits T5X."""

  model: incremental.InferenceModel
  params: weights.Weights
  prefill_fn: Callable
  generate_fn: Callable
  _batch_size: int
  _input_vocabulary: Vocabulary
  _output_vocabulary: Vocabulary
  sample_ids: jax.Array
  max_input_length: int
  max_generate_length: int

  def __init__(
      self,
      cfg,
      _input_vocabulary,
      batch_size,
      task_feature_lengths,
  ):
    model, params, prefill_fn, generate_fn = return_minimal_palm(cfg)  # pylint: disable = unbalanced-tuple-unpacking
    self.model = model
    self.params = params
    self.prefill_fn = prefill_fn
    self.generate_fn = generate_fn
    self.get_logits_fn = model.instantiate_prefill_fn(return_full_chunk=True)
    self._batch_size = batch_size
    self._input_vocabulary = _input_vocabulary
    self._output_vocabulary = _input_vocabulary
    self.max_input_length = task_feature_lengths['inputs']
    self.max_generate_length = task_feature_lengths['targets']

    # make a custom model for the common_prefix / prefill sections
    # this is only function defs not params
    prefix_model_cfg = dataclasses.replace(
        cfg, kv_cache_sharding=0, batch_unsharded=True
    )
    prefix_model, _, prefix_prefill_fn, _ = return_minimal_palm(
        prefix_model_cfg, params_already_loaded=True
    )
    self.prefix_model = prefix_model
    self.prefix_prefill_fn = prefix_prefill_fn

  def predict_batch(self, params, batch):
    """Does an inference step.

    Args:
      params: Pytree definition of weights
      batch: assumed to have fields {'decoder_causal_attention': int [batch,
        length], 'decoder_input_tokens': same}

    Returns:
      inferences: (output.tokens, {'scores': output_result.per_token_scores})
      tokens is either [batch, tokens] or [batch, num_decodes, tokens]
    """

    return self.predict_batch_with_aux(params, batch)

  def predict_batch_with_aux(
      self,
      params,
      batch,
      rng = None,
      num_decodes = 1,
      temperature = 0.7,
      return_all_decodes = True,
      decoder_params=None,
  ):
    with jax.named_scope('make_batch'):
      prefix, prompt = self.make_batch(batch)
    processed_cache = self.process_cache(params, prompt, prefix)
    with jax.named_scope('generate'):
      sample_hyperparams = sampling.SamplingHyperParams(temperature=temperature)
      sample_ids = np.arange(self._batch_size * num_decodes)
      output, output_result = self.model.generate(
          params,
          self.generate_fn,
          processed_cache,
          sample_ids,
          sample_hyperparams,
      )

    if num_decodes > 1:
      tokens = output.tokens.reshape((self._batch_size, num_decodes, -1))
      scores = output_result.per_token_scores.sum(-1).reshape(
          (self._batch_size, num_decodes)
      )
    else:
      tokens = output.tokens
      scores = output_result.per_token_scores.sum(-1)

    inferences = tokens, {'scores': scores}

    return inferences
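  # Example call (illustrative; `batch` follows the format documented in
  # predict_batch above):
  #   tokens, aux = t5x_model.predict_batch_with_aux(
  #       params, batch, num_decodes=4, temperature=0.7
  #   )
  #   # tokens: [batch, num_decodes, length]
  #   # aux['scores']: [batch, num_decodes]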

  def score_batch(
      self,
      params,
      batch,
      return_intermediates = False,
  ):
    """Scores a batch: sum of per-token log-likelihoods per sequence."""
    inputs_lengths = np.sum(batch['decoder_causal_attention'], axis=1) - 1
    masked_inputs = (
        batch['decoder_input_tokens'] * batch['decoder_causal_attention']
    )
    score_chunk = chunk.Chunk(masked_inputs, inputs_lengths)  # [batch, time]

    # TODO(sholto): We could play the common prefix trick here too
    score_result = self.model.prefill(
        self.params, self.get_logits_fn, [], score_chunk
    )
    # TODO(sholto): Test if manual version made for cascades uses less memory
    token_scores = ce_loss(score_result, batch)
    sequence_scores = token_scores.sum(-1)
    return sequence_scores

  def make_batch(
      self,
      batch,
      extract_prefix = False,
      common_prefix_heuristic = 32,
  ):
    """Builds (prefix, prompt) chunks from a tokenised batch."""
    inputs_lengths = np.sum(batch['decoder_causal_attention'], axis=1) - 1
    masked_inputs = (
        batch['decoder_input_tokens'] * batch['decoder_causal_attention']
    )
    inputs = masked_inputs[:, : self.max_input_length]  # [batch, time]

    if extract_prefix:
      # NB: the below is not jax jittable.
      common_prefix = find_common_prefix(inputs)  # integer
      # Heuristic for whether prefix extraction is worth doing
      if (common_prefix > common_prefix_heuristic) and (
          self.max_input_length - common_prefix_heuristic > common_prefix
      ):
        logging.info('Detected common prefix of length %i', common_prefix)
        prefix = chunk.Chunk(
            jnp.expand_dims(inputs[0, :common_prefix], 0),
            jnp.array([common_prefix]),
        )
        prompt = chunk.Chunk(
            inputs[:, common_prefix:], inputs_lengths - common_prefix
        )
        return prefix, prompt
    # Default to no prefix extraction
    prompt = chunk.Chunk(inputs, inputs_lengths)
    prefix = None
    return prefix, prompt
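  # Worked example of the heuristic (illustrative numbers): with
  # max_input_length=512 and common_prefix_heuristic=32, a detected common
  # prefix of 100 tokens satisfies 100 > 32 and 512 - 32 > 100, so the batch
  # is split into a shared prefix chunk and per-example prompt chunks; a
  # 10-token prefix would not be worth splitting off, and the whole input is
  # returned as a single prompt chunk with no prefix.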

  def process_cache(
      self, params, prompt, prefix=None
  ):
    """Runs prefill for the optional shared prefix, then for the prompts."""
    processed_cache = []
    if prefix is not None:
      with jax.named_scope('common_prefill'):
        # the common prefix will be batch size 1, shard appropriately
        common_prefix = self.prefix_model.prefill(
            params, self.prefix_prefill_fn, [], prefix
        )
        processed_cache.append(common_prefix)
    with jax.named_scope('different_prefill'):
      prompt = self.model.prefill(
          params, self.prefill_fn, processed_cache, prompt
      )
      processed_cache.append(prompt)
    return processed_cache
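
# End-to-end sketch (illustrative; the batch size, feature lengths and config
# are placeholder values, and `batch` follows the format documented in
# predict_batch):
#
#   t5x_model = InferenceT5X(
#       cfg,
#       _input_vocabulary=checkpoint.load_vocab(),
#       batch_size=8,
#       task_feature_lengths={'inputs': 1024, 'targets': 256},
#   )
#   tokens, aux = t5x_model.predict_batch(t5x_model.params, batch)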