# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common functions for setting up inputs for adversarial prefix learning."""

import collections
import functools
import string
from typing import Any, List, Optional, Tuple, TypedDict

from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
from paxml import trainer_lib
from praxis import base_layer
from praxis import py_utils
from praxis import pytypes
import seqio
from tensorflow_probability.substrates import jax as tfp

# Define aliases for brevity
NestedMap = py_utils.NestedMap
JTensor = pytypes.JTensor
RANDOM = base_layer.RANDOM
DECODE_CACHE = base_layer.DECODE_CACHE


def calc_max_onehot(x):
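  # Converts each distribution along the last axis to a hard one-hot vector at
  # its argmax, e.g. [[0.1, 0.7, 0.2]] -> [[0., 1., 0.]].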
  return jax.nn.one_hot(jnp.argmax(x, -1), x.shape[-1], dtype=x.dtype)


class GumbelSoftmaxParams(TypedDict):
  temp: float
  hard: bool


class WrappedModel:
  """Wrapper for a Pax model that uses identity-based equality comparisons.

  This makes it possible to pass the model into jax functions such as `jit`,
  even if it is not hashable.  Note: jax will retrace jit-compiled functions
  whenever they are called with a new instance of the wrapped model.
  """

  def __init__(self, model):
    self.model = model

  def __eq__(self, other):
    return isinstance(other, WrappedModel) and self.model is other.model

  def __hash__(self):
    return id(self.model)


@functools.partial(jax.vmap, in_axes=(0, None, 0))
def _gumbel_softmax_part(
    logits, temp, rng
):
  # Temp must be >0 or we will get NaNs.
  # Can't use if statement to validate input or this won't jit well.
  converted_logits = jnp.array(logits, dtype=jnp.float32)
  dist = tfp.distributions.RelaxedOneHotCategorical(
      temp, logits=converted_logits
  )
  return jnp.array(dist.sample(seed=rng), dtype=logits.dtype)


def _gumbel_softmax_batch_keys(inputs, temp, hard,
                               all_rngs):
  """Helper function for gumbel_softmax()."""

  def flatten(x):
    return jnp.reshape(x, (-1, x.shape[-1]))

  flat_inputs = flatten(inputs)
  all_rngs = flatten(all_rngs)

  y = _gumbel_softmax_part(flat_inputs, temp, all_rngs)

  def _hard_fn():
    y_hard = calc_max_onehot(y)
    return jax.lax.stop_gradient(y_hard - y) + y

  result = jax.lax.cond(hard, _hard_fn, lambda: y)

  return jnp.reshape(result, inputs.shape)


def gumbel_softmax(
    inputs, temp, hard, rng
):
  """Draws from the gumbel softmax distribution over the given inputs.

  Args:
    inputs: Array of logits; the last dimension ranges over the vocabulary.
    temp: temperature of the gumbel softmax
    hard: If True, return hard one-hot samples (with straight-through
      gradients). Else return the relaxed (soft) samples.
    rng: A single random key for the operation

  Returns:
    Samples from a gumbel softmax distribution for each set of logits in inputs
  """
  all_rngs = jax.random.split(rng, np.prod(inputs.shape[:-1]))
  return _gumbel_softmax_batch_keys(inputs, temp, hard, all_rngs)


def smooth_logits(
    tokens, smooth_factor, logits_dim, dtype
):
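  # Returns log-probabilities for a label-smoothed one-hot encoding of
  # `tokens`. For example, with smooth_factor=0.1 and logits_dim=3, token 0
  # gives smoothed probabilities [0.9, 0.1, 0.1] / 1.1 ~= [0.82, 0.09, 0.09]
  # before the log; smooth_factor=0.5 (as used in make_inputs) yields a
  # uniform distribution.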
  onehot = jax.nn.one_hot(tokens, logits_dim, dtype=dtype)
  smoothed_onehot = onehot * (1 - 2 * smooth_factor) + smooth_factor
  return jnp.log(smoothed_onehot / jnp.sum(smoothed_onehot, -1, keepdims=True))


def replicate_batch(x, batch_size):
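  # Stacks `x` along a new leading batch axis,
  # e.g. shape (L, V) -> (batch_size, L, V).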
  return jnp.array([x] * batch_size)


def replicate_batch_tree(tree, batch_size):
  return jax.tree_map(
      functools.partial(replicate_batch, batch_size=batch_size), tree)


def contains_only(vocab_string, chars):
  """Checks if the string only contains only the listed chars."""
  return all(c in chars for c in vocab_string)


def keep_alphanumeric_punct(
    index,
    vocabulary,
    exclude_no_space,
):
  """Returns True if the string contains ascii chars or punct, but not both."""
  if index == 1005:
    # Keep space.
    return True

  vocab_string = vocabulary.decode([index])

  if not vocab_string:
    return False

  alphanum_chars = string.ascii_letters + string.digits

  if contains_only(vocab_string, alphanum_chars):
    if exclude_no_space:
      return (len(vocab_string) == 1) or (
          # Add a non functional token to the beginning
          # to detect if there's a space.
          ' '
          in vocabulary.decode([1011, index])
      )
    else:
      return True

  return contains_only(vocab_string, string.punctuation)


def keep_vocab(
    index,
    vocabulary,
    exclude_tokens,
    exclude_no_space,
):
  return keep_alphanumeric_punct(index, vocabulary, exclude_no_space) and (
      vocabulary.decode([index]) not in exclude_tokens
  )


def get_vocab_mask(
    vocabulary,
    exclude_tokens,
    exclude_no_space,
):
  """Masks out tokens where keep_alphanumeric_punct returns false."""
  mask = jnp.array([
      keep_vocab(i, vocabulary, exclude_tokens, exclude_no_space)
      for i in range(vocabulary.vocab_size)
  ])

  return mask


def make_inputs(
    prefix,
    input_len,
    decode_len,
    input_for_classify,
    vocabulary,
    vocab_mask,
    dtype,
):
  """Returns the tokens, logits and parameters for the input to the model.

  Args:
    prefix: The input text before the adversarial input.
    input_len: The length of the adversarial input.
    decode_len: The length of the decode output.
    input_for_classify: The classifier uses this input instead of the
      adversarial input. Set to None to pass the adversarial input to the
      classifier.
    vocabulary: The vocabulary
    vocab_mask: Boolean mask of whether to keep or discard each vocab index.
    dtype: The type of the input to generate.

  Returns:
    tokens: A dict of tokens. These are assembled in loss_fn to get the full
    input to the model.
      prefix: Tokens which go before the adversarial input.
      sep: The tokens between the adversarial input and the model output.
      decode: The placeholder tokens for the model output.
      input_for_classify: The classifier input to use instead of the
      adversarial input.
      classify: The tokens which tell the model to classify.
    logits: The adversarial logits to optimize.
    vocab_mask: Same as the input.
    label: The label token.
  """
  if input_for_classify is not None:
    input_for_classify_tokens = jnp.array(
        [0] + list(vocabulary.encode(input_for_classify)))
  else:
    input_for_classify_tokens = jnp.array([], dtype=jnp.int32)

  return NestedMap.FromNestedDict(
      dict(
          tokens=dict(
              prefix=jnp.array([0] + list(vocabulary.encode(prefix))),
              sep=jnp.array(vocabulary.encode('SEP')),
              decode=jnp.array([-1] * decode_len),
              input_for_classify=input_for_classify_tokens,
              classify=jnp.array(vocabulary.encode('SAFETY ')),
          ),
          logits=smooth_logits(
              [0] * input_len, 0.5, len(vocab_mask), dtype=dtype
          ),  # pytype: disable=wrong-arg-types  # jax-ndarray
          vocab_mask=vocab_mask,
          label=vocabulary.encode('1')[-1],
      )
  )


def simple_decode(
    mdl_vars,
    input_onehot,
    use_decoded_mask,
    prefix_len,
    gs_params,
    prng_key,
    greedy,
    wrapped_model,
):
  # pylint: disable-next=g-doc-args
  # pylint: disable-next=g-doc-return-or-yield
  """Calls simple_decode() of the given model.

  We pass the model in a wrapper because it is unhashable so cannot be jitted.
  """
  full_mdl_vars = mdl_vars.copy()
  full_mdl_vars['params'] = {
      'embedding_model': mdl_vars['params']['lm']['softmax'],
      'model': mdl_vars['params']
  }

  with base_layer.JaxContext.new_context(
      hparams=base_layer.JaxContext.HParams(do_eval=True)):
    outputs, _ = wrapped_model.model.apply(
        full_mdl_vars,
        input_onehot,
        use_decoded_mask,
        prefix_len,
        gs_params,
        greedy,
        method=wrapped_model.model.simple_decode,
        rngs={RANDOM: prng_key},
        mutable=[DECODE_CACHE],
    )
    return outputs


def logits_loss(output_logits, final_label):
  # How closely the final logit matches the label.
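  # This is the negative probability assigned to `final_label` at the last
  # position, so minimizing it pushes that probability toward 1.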
  output_probs = nn.softmax(output_logits[:, -1], axis=-1)
  return -output_probs[:, final_label]


def filter_onehot(input_onehot, vocab_mask):
  # This function only works with positive inputs.
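  # Zeroes out masked vocab entries and renormalizes each distribution over
  # the last axis, e.g. with mask [True, False, True],
  # [0.2, 0.5, 0.3] -> [0.4, 0.0, 0.6]. Expects (batch, seq, vocab) inputs
  # since the mask is broadcast over the first two axes.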
  input_onehot = jnp.where(vocab_mask[None, None, :], input_onehot, 0.0)
  return input_onehot / jnp.sum(input_onehot, axis=-1, keepdims=True)


def difference_loss(
    input_logits,
    input_onehot,
    output_logits,
    vocab_mask,
):
  """How much do the updated token probabilities differ from the predicted ones.
316
  """
  del input_onehot

  normalized_output_logits = jax.nn.log_softmax(output_logits, axis=-1)
  input_probability = jax.nn.softmax(input_logits, axis=-1)
  input_probability = filter_onehot(input_probability, vocab_mask)

  result_logits = input_probability * normalized_output_logits
  result_logits = jnp.sum(result_logits, -1)

  # Divide by a constant to scale the loss.
  return -result_logits / 30.0


def loss_fn(
    input_logits,
    mdl_vars,
    full_input,
    prng_key,
    input_gs_params,
    decode_gs_params,
    difference_loss_weight,
    batch_size,
    is_eval,
    wrapped_model,
    return_aux_tokens,
):
  """Computes predicted output label and compares against true label.

  The input to simple_decode is constructed from full_input.tokens, and
  input_logits.
  If there are tokens in input_for_classify, the input is
    prefix, input_logits, sep, decode. The decode tokens are placeholders.
    simple_decode uses this input to get the output tokens.
    The next input is input_for_classify, sep, decode_output, classify.
    simple_decode uses this input to get the classifier score.
    Note that only input_for_classify is given when calculating the classifier
    score, not input_logits.

  If there are no input_for_classify tokens, the input is
    prefix, input_logits, sep, decode, classify
    simple_decode gets the decoded model output, then adds the classify tokens
    to get the classify score.

  Args:
    input_logits: The adversarial logits. The gradient is computed with respect
      to these.
    mdl_vars: Model vars
    full_input: See make_inputs. Use input_logits instead of the logits from
      here.
    prng_key: Random key.
    input_gs_params: The temp and hard for the gumbel softmax on the input.
    decode_gs_params: The temp and hard for the gumbel softmax during decode.
    difference_loss_weight: Weight of difference_loss in the total loss, i.e.
      how much to optimize the probability of the input.
    batch_size: Number of replicated copies of the input sampled per step.
    is_eval: True to use maximum instead of sampling with gumbel softmax.
    wrapped_model: Pax model.
    return_aux_tokens: True to return additional tokens for logging.

  Returns:
    total_loss: weighted sum of adversarial and difference losses
    losses: Map of loss and difference_loss
    aux_tokens: Returned if return_aux_tokens.
  """
  input_logits_batch = replicate_batch(input_logits, batch_size)

  if is_eval:
    # Take the softmax because
    # construct_decode_input only works with positive inputs.
    input_onehot = jax.nn.softmax(input_logits_batch, axis=-1)
    assert batch_size == 1
  else:
    prng_key, gs_prng_key = jax.random.split(prng_key)
    input_onehot = gumbel_softmax(input_logits_batch, input_gs_params['temp'],
                                  input_gs_params['hard'], gs_prng_key)

  input_onehot = filter_onehot(input_onehot, full_input.vocab_mask)
  if is_eval:
    input_onehot = calc_max_onehot(input_onehot)

  # Converts the tokens to onehot representation.
  onehot_tokens = {}
  for token_type, tokens in full_input.tokens.items():
    onehot_tokens[token_type] = replicate_batch(
        jax.nn.one_hot(
            tokens, input_onehot.shape[-1], dtype=input_onehot.dtype
        ),
        input_onehot.shape[0],
    )

  # Construct the input.
  decode_prefix_onehot = jnp.concatenate(
      [onehot_tokens['prefix'], input_onehot, onehot_tokens['sep']], axis=1)
  decode_prefix_len = decode_prefix_onehot.shape[1]
  use_decoded_mask = jnp.concatenate([
      jnp.zeros(decode_prefix_onehot.shape[1]),
      jnp.ones(onehot_tokens['decode'].shape[1])
  ],
                                     axis=0)
  decode_input_onehot = jnp.concatenate(
      [decode_prefix_onehot, onehot_tokens['decode']], axis=1)

  has_input_for_classify = full_input.tokens.input_for_classify.shape[0] > 0

  if not has_input_for_classify:
    # Add the classify tokens.
    decode_input_onehot = jnp.concatenate(
        [decode_input_onehot, onehot_tokens['classify']], axis=1)
    use_decoded_mask = jnp.concatenate(
        [use_decoded_mask,
         jnp.zeros(onehot_tokens['classify'].shape[1])],
        axis=0)

  prng_key, decode_key = jax.random.split(prng_key)
  decode_out = simple_decode(  # pytype: disable=wrong-arg-types  # jnp-type
      mdl_vars,
      decode_input_onehot,
      use_decoded_mask,
      decode_prefix_len,
      decode_gs_params,
      decode_key,
      greedy=is_eval,
      wrapped_model=wrapped_model,
  )

  # Calc difference_loss
  update_start = full_input.tokens['prefix'].shape[0] - 1
  # Axis 1 not 0 because input_onehot has been replicated into multiple batches.
  update_end = update_start + input_onehot.shape[1]
  updated_decoded_logits = decode_out['logits'][:, update_start:update_end, :]
  difference_loss_value = jnp.mean(
      difference_loss(
          input_logits_batch,
          input_onehot,
          updated_decoded_logits,
          full_input.vocab_mask,
      )
  )

  if has_input_for_classify:
    # Construct the classifier input.
    classify_input_onehot = jnp.concatenate([
        onehot_tokens['input_for_classify'], onehot_tokens['sep'],
        decode_out['onehot'], onehot_tokens['classify']
    ],
                                            axis=1)
    decode_mask_classify = jnp.zeros(classify_input_onehot.shape[1])
    classify_prefix_len = len(decode_mask_classify)

    prng_key, decode_key = jax.random.split(prng_key)
    classify_out = simple_decode(  # pytype: disable=wrong-arg-types  # jnp-array
        mdl_vars,
        classify_input_onehot,
        decode_mask_classify,
        classify_prefix_len,
        decode_gs_params,
        decode_key,
        greedy=is_eval,
        wrapped_model=wrapped_model,
    )

    loss = jnp.mean(logits_loss(classify_out['logits'], full_input.label))

  else:
    loss = jnp.mean(logits_loss(decode_out['logits'], full_input.label))

  total_loss = loss + difference_loss_value * difference_loss_weight
  losses = {'loss': loss, 'difference_loss': difference_loss_value}

  if return_aux_tokens:
    aux_tokens = {
        'decode_prefix': jnp.argmax(decode_prefix_onehot, -1),
        'decode_input': jnp.argmax(decode_input_onehot, -1),
        'decode_out_onehot': jnp.argmax(decode_out['onehot'], -1),
        'decode_out_logits': jnp.argmax(decode_out['logits'], -1),
    }
    if has_input_for_classify:
      aux_tokens['classify_input'] = jnp.argmax(classify_input_onehot, -1)
    return total_loss, losses, aux_tokens
  else:
    return total_loss, losses


loss_fn_jit = jax.jit(
    loss_fn,
    static_argnames=[
        'wrapped_model',
        'batch_size',
        'is_eval',
        'return_aux_tokens',
    ],
)
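# Note: changing any of the static arguments above (including `wrapped_model`,
# which hashes by identity; see WrappedModel) causes loss_fn to be retraced.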


def loss_grad(
    full_input,
    model_states,
    prng_key,
    input_gs_params,
    decode_gs_params,
    difference_loss_weight,
    batch_size,
    wrapped_model,
):
  """Returns the loss, and gradient of loss_fn."""
  (_, loss), grad = jax.value_and_grad(loss_fn, has_aux=True)(
      full_input.logits,
      model_states.mdl_vars,
      full_input,
      prng_key,
      input_gs_params,
      decode_gs_params,
      difference_loss_weight,
      batch_size,
      is_eval=False,
      wrapped_model=wrapped_model,
      return_aux_tokens=False,
  )
  return loss, grad


@functools.partial(
    jax.pmap,
    in_axes=(0, 0, 0, None, None, None, None, None, None),
    static_broadcasted_argnums=[7, 8],
    axis_name='batch')
# Arguments must be passed by position not keyword because of pmap
def update_input_rep_par(
    full_input,
    model_states,
    prng_key,
    lr,
    input_gs_params,
    decode_gs_params,
    difference_loss_weight,
    local_batch_size,
    wrapped_model,
):
  """Updates the input logits to minimize the loss."""
  prng_key, loss_rng = jax.random.split(prng_key)

  loss, grad = loss_grad(
      full_input,
      model_states,
      loss_rng,
      input_gs_params,
      decode_gs_params,
      difference_loss_weight,
      local_batch_size,
      wrapped_model,
  )

  grad = jax.lax.pmean(grad, axis_name='batch')
  loss = jax.lax.pmean(loss, axis_name='batch')

  optimizer = optax.adam(lr)
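  # optax optimizers are stateless objects; only `model_states.opt_states` and
  # the returned `opt_state` carry optimizer state, which the caller is
  # expected to thread back in on the next step.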
  updates, opt_state = optimizer.update(grad, model_states.opt_states)
  input_logits = optax.apply_updates(full_input.logits, updates)

  return input_logits, opt_state, loss, prng_key


def eval_label_prob(
    full_input,
    model_states,
    verbose,
    vocabulary,
    wrapped_model,
):
  """Computes the probability of the label after the decoding step.

  Uses the token with the highest probability for the input, and during
  decoding. It doesn't use gumbel softmax.

  Args:
    full_input: The result of make_inputs. Used for the logits, and vocab_mask.
    model_states: For the model.
    verbose: Prints the inputs and outputs if True.
    vocabulary: Model vocab.
    wrapped_model: Task including the model.

  Returns:
    The probability of the label after the decoding.
    The full tokens that are used to calculate the score. This is the
    adversarial input followed by the separator, followed by the decoded output.
  """
  _, losses, aux_tokens = loss_fn_jit(
      full_input.logits,
      model_states.mdl_vars,
      full_input,
      jax.random.PRNGKey(0),
      input_gs_params=None,
      decode_gs_params=None,
      difference_loss_weight=0.0,
      batch_size=1,
      is_eval=True,
      wrapped_model=wrapped_model,
      return_aux_tokens=True,
  )
  loss = losses['loss']

  if verbose:
    display_dict = collections.defaultdict(list)

    for key, tokens in aux_tokens.items():
      display_dict[key] = _display_tokens(tokens[0], vocabulary)

    display_dict = {k: pd.Series(v) for k, v in display_dict.items()}

    print(pd.DataFrame(display_dict).to_string())

  decode_out_tokens = aux_tokens['decode_out_onehot'][
      0, :len(full_input.tokens['decode']) + 1]
  decode_input_output = jnp.concatenate(
      [aux_tokens['decode_prefix'][0], decode_out_tokens])

  return -loss, losses['difference_loss'], decode_input_output


def make_model_input(tokens):
  num_tokens = len(tokens)
  return NestedMap.FromNestedDict({
      'ids': tokens,
      'labels': np.zeros([num_tokens], np.int32),
      'paddings': np.zeros([num_tokens], np.int32),
      'weights': np.zeros([num_tokens], np.int32),
      'segment_ids': None,
      'segment_pos': None,
  })


@functools.partial(jax.jit, static_argnames='wrapped_model')
def regular_decode(model_states, input_tokens, wrapped_model):
  """Decodes using the built in PAX decoding."""
  model_input = replicate_batch_tree(make_model_input(input_tokens), 1)
  var_weight_hparams = wrapped_model.model.abstract_init_with_metadata(
      model_input
  )
  (_, per_example_out, _), _ = trainer_lib.decode_step(
      wrapped_model.model,
      model_states.to_eval_state(),
      jax.random.PRNGKey(1234),
      var_weight_hparams,
      model_input,
      fprop_dtype=wrapped_model.model.fprop_dtype,
  )
  return per_example_out


def dec_enc(tokens,
            vocabulary):
  """Decodes the tokens with the vocab then encodes them again.

  It will usually give the same tokens back. This is used to make sure the
  input consists of tokens that the tokenizer could actually produce.

  Args:
   tokens: tokens
   vocabulary: vocabulary

  Returns:
    The tokens after decoding then encoding.
  """
  return jnp.array([0] + list(vocabulary.encode(vocabulary.decode(tokens))))


def filter_after_eos(tokens):
  """Removes all tokens after the eos token (eos token has value 1)."""
  # The 0th element of where is the first occurrence.
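  # e.g. [5, 7, 1, 9, 1] -> [5, 7]. The result has a data-dependent shape, so
  # this is meant to be called outside jit.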
  index = jnp.where(tokens == 1)[0]
  if index.shape[0] == 0:
    return tokens
  return tokens[:index[0]]


def eval_label_prob_reg_decode(
    full_input,
    model_states,
    verbose,
    vocabulary,
    wrapped_model,
    use_dec_enc,
):
  """Similar to eval_label_prob, but it uses the normal PAX decode algorithm.

  Uses tokens instead of a one-hot encoding. It usually gives the same results
  as eval_label_prob, but they can occasionally differ due to floating point
  error.

  Args:
    full_input: The result of make_inputs. Used for the logits, and vocab_mask.
    model_states: For the model.
    verbose: Prints the inputs and outputs if True.
    vocabulary: Model vocab.
    wrapped_model: Task including the model.
    use_dec_enc: True to apply the dec_enc to the input.

  Returns:
    The probability of the label after the decoding.
    The full tokens that are used to calculate the score. This is the
    adversarial input followed by the separator, followed by the decoded output.
  """
  input_onehot = jax.nn.softmax(full_input.logits, axis=-1)
  input_onehot = filter_onehot(input_onehot, full_input.vocab_mask)[0]
  input_tokens = jnp.argmax(input_onehot, -1)

  if use_dec_enc:
    input_tokens = dec_enc(input_tokens, vocabulary)
  full_input_tokens = jnp.concatenate(
      [full_input.tokens.prefix, input_tokens, full_input.tokens.sep], 0)
  if use_dec_enc:
    full_input_tokens = dec_enc(full_input_tokens, vocabulary)

  decode_out = regular_decode(model_states, full_input_tokens, wrapped_model)

  decode_end_i = full_input_tokens.shape[0] + full_input.tokens.decode.shape[0]
  if full_input.tokens.input_for_classify.shape[0] > 0:
    decode_end_i += 1
  decoded_tokens = decode_out['output_ids'][
      0, 0, full_input_tokens.shape[0]:decode_end_i]
  decoded_tokens = filter_after_eos(decoded_tokens)

  full_decode_tokens = jnp.concatenate([full_input_tokens, decoded_tokens], 0)

  if full_input.tokens.input_for_classify.shape[0] > 0:
    safety_classifier_input = jnp.concatenate([
        full_input.tokens.input_for_classify, full_input.tokens.sep,
        decoded_tokens, full_input.tokens.classify
    ], 0)
  else:
    safety_classifier_input = jnp.concatenate(
        [full_decode_tokens, full_input.tokens.classify], 0)

  safety_out = regular_decode(
      model_states, safety_classifier_input, wrapped_model
  )

  safety_token = safety_out['output_ids'][0, 0,
                                          safety_classifier_input.shape[0]]
  safety_token_prob = jnp.exp(
      safety_out['logprobs'][0, 0, safety_classifier_input.shape[0]])

  if verbose:
    display_dict = collections.defaultdict(list)

    display_dict['decoded_tokens'] = _display_tokens(full_decode_tokens,
                                                     vocabulary)
    display_dict['safety_out'] = _display_tokens(
        safety_out['output_ids'][0, 0, :safety_classifier_input.shape[0] + 1],
        vocabulary)

    display_dict = {k: pd.Series(v) for k, v in display_dict.items()}

    print(pd.DataFrame(display_dict).to_string())

  if safety_token == full_input.label:
    return safety_token_prob, full_decode_tokens
  else:
    return 1 - safety_token_prob, full_decode_tokens


def _display_tokens(
    tokens,
    vocabulary):
  display_vals = []
  for token in tokens:
    display_vals.append(vocabulary.decode([token]))
  return display_vals
