# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16"""Common functions for setting up inputs for adversarial prefix learning."""
17
18import collections19import functools20import string21from typing import Any, List, Optional, Tuple, TypedDict22
23from flax import linen as nn24import jax25import jax.numpy as jnp26import numpy as np27import optax28import pandas as pd29from paxml import trainer_lib30from praxis import base_layer31from praxis import py_utils32from praxis import pytypes33import seqio34from tensorflow_probability.substrates import jax as tfp35
36# Define aliases for brevity
37NestedMap = py_utils.NestedMap38JTensor = pytypes.JTensor39RANDOM = base_layer.RANDOM40DECODE_CACHE = base_layer.DECODE_CACHE41
42
43def calc_max_onehot(x):44return jax.nn.one_hot(jnp.argmax(x, -1), x.shape[-1], dtype=x.dtype)45
46
47class GumbelSoftmaxParams(TypedDict):48temp: float49hard: bool50
51
class WrappedModel:
  """Wrapper for a Pax model that uses identity-based equality comparisons.

  This makes it possible to pass the model into jax functions such as `jit`,
  even if it is not hashable. Note: jax will retrace jit-compiled functions
  whenever they are called with a new instance of the wrapped model.
  """

  def __init__(self, model):
    self.model = model

  def __eq__(self, other):
    return isinstance(other, WrappedModel) and self.model is other.model

  def __hash__(self):
    return id(self.model)
@functools.partial(jax.vmap, in_axes=(0, None, 0))
def _gumbel_softmax_part(
    logits, temp, rng
):
  # Temp must be >0 or we will get NaNs.
  # Can't use if statement to validate input or this won't jit well.
  converted_logits = jnp.array(logits, dtype=jnp.float32)
  dist = tfp.distributions.RelaxedOneHotCategorical(
      temp, logits=converted_logits
  )
  return jnp.array(dist.sample(seed=rng), dtype=logits.dtype)
def _gumbel_softmax_batch_keys(inputs, temp, hard,
                               all_rngs):
  """Helper function for gumbel_softmax()."""

  def flatten(x):
    return jnp.reshape(x, (-1, x.shape[-1]))

  flat_inputs = flatten(inputs)
  all_rngs = flatten(all_rngs)

  y = _gumbel_softmax_part(flat_inputs, temp, all_rngs)

  def _hard_fn():
    # Straight-through estimator: the forward pass is one-hot, but the
    # gradient flows through the soft sample.
    y_hard = calc_max_onehot(y)
    return jax.lax.stop_gradient(y_hard - y) + y

  result = jax.lax.cond(hard, _hard_fn, lambda: y)

  return jnp.reshape(result, inputs.shape)
def gumbel_softmax(
    inputs, temp, hard, rng
):
  """Draws from the gumbel softmax distribution over the given inputs.

  Args:
    inputs: Array of logits whose last axis is the vocabulary dimension.
    temp: Temperature of the gumbel softmax.
    hard: If true, sample a one-hot vector (straight-through). Else return the
      relaxed (soft) sample.
    rng: A single random key for the operation.

  Returns:
    Samples from a gumbel softmax distribution for each set of logits in
    inputs.
  """
  all_rngs = jax.random.split(rng, np.prod(inputs.shape[:-1]))
  return _gumbel_softmax_batch_keys(inputs, temp, hard, all_rngs)
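# Illustrative usage sketch for gumbel_softmax(). The shapes, temperature and
# random key below are arbitrary assumptions for the example, not values used
# elsewhere in this file.
def _example_gumbel_softmax_usage():
  rng = jax.random.PRNGKey(0)
  logits = jnp.zeros((2, 4, 16))  # (batch, prefix_len, vocab_size)
  # Soft, differentiable sample over the vocabulary at each position.
  soft = gumbel_softmax(logits, temp=1.0, hard=False, rng=rng)
  # One-hot sample whose gradient flows through the soft sample
  # (straight-through estimator).
  hard = gumbel_softmax(logits, temp=1.0, hard=True, rng=rng)
  return soft, hard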
def smooth_logits(
    tokens, smooth_factor, logits_dim, dtype
):
  """Returns log-probabilities of a label-smoothed one-hot encoding of tokens."""
  onehot = jax.nn.one_hot(tokens, logits_dim, dtype=dtype)
  smoothed_onehot = onehot * (1 - 2 * smooth_factor) + smooth_factor
  return jnp.log(smoothed_onehot / jnp.sum(smoothed_onehot, -1, keepdims=True))


def replicate_batch(x, batch_size):
  return jnp.array([x] * batch_size)


def replicate_batch_tree(tree, batch_size):
  return jax.tree_map(
      functools.partial(replicate_batch, batch_size=batch_size), tree)
def contains_only(vocab_string, chars):
  """Checks if the string contains only the listed chars."""
  return all(c in chars for c in vocab_string)
def keep_alphanumeric_punct(
    index,
    vocabulary,
    exclude_no_space,
):
  """Returns True if the string contains alphanumeric chars or punct, but not both."""
  if index == 1005:
    # Keep space.
    return True

  vocab_string = vocabulary.decode([index])

  if not vocab_string:
    return False

  alphanum_chars = string.ascii_letters + string.digits

  if contains_only(vocab_string, alphanum_chars):
    if exclude_no_space:
      return (len(vocab_string) == 1) or (
          # Add a non-functional token to the beginning
          # to detect if there's a space.
          ' '
          in vocabulary.decode([1011, index])
      )
    else:
      return True

  return contains_only(vocab_string, string.punctuation)
def keep_vocab(
    index,
    vocabulary,
    exclude_tokens,
    exclude_no_space,
):
  return keep_alphanumeric_punct(index, vocabulary, exclude_no_space) and (
      vocabulary.decode([index]) not in exclude_tokens
  )


def get_vocab_mask(
    vocabulary,
    exclude_tokens,
    exclude_no_space,
):
  """Masks out tokens where keep_vocab returns False."""
  mask = jnp.array([
      keep_vocab(i, vocabulary, exclude_tokens, exclude_no_space)
      for i in range(vocabulary.vocab_size)
  ])

  return mask
def make_inputs(
    prefix,
    input_len,
    decode_len,
    input_for_classify,
    vocabulary,
    vocab_mask,
    dtype,
):
  """Returns the tokens, logits and parameters for the input to the model.

  Args:
    prefix: The input text before the adversarial input.
    input_len: The length of the adversarial input.
    decode_len: The length of the decode output.
    input_for_classify: The classifier uses this input instead of the
      adversarial input. Set to None to pass the adversarial input to the
      classifier.
    vocabulary: The vocabulary.
    vocab_mask: Boolean mask of whether to keep or discard each vocab index.
    dtype: The type of the input to generate.

  Returns:
    tokens: A dict of tokens. These are assembled in loss_fn to get the full
      input to the model.
      prefix: Tokens which go before the adversarial input.
      sep: The tokens between the adversarial input and the model output.
      decode: The placeholder tokens for the model output.
      input_for_classify: The classifier input to use instead of the
        adversarial input.
      classify: The tokens which tell the model to classify.
    logits: The adversarial logits to optimize.
    vocab_mask: Same as the input.
    label: The label token.
  """
  if input_for_classify is not None:
    input_for_classify_tokens = jnp.array(
        [0] + list(vocabulary.encode(input_for_classify)))
  else:
    input_for_classify_tokens = jnp.array([], dtype=jnp.int32)

  return NestedMap.FromNestedDict(
      dict(
          tokens=dict(
              prefix=jnp.array([0] + list(vocabulary.encode(prefix))),
              sep=jnp.array(vocabulary.encode('SEP')),
              decode=jnp.array([-1] * decode_len),
              input_for_classify=input_for_classify_tokens,
              classify=jnp.array(vocabulary.encode('SAFETY ')),
          ),
          logits=smooth_logits(
              [0] * input_len, 0.5, len(vocab_mask), dtype=dtype
          ),  # pytype: disable=wrong-arg-types  # jax-ndarray
          vocab_mask=vocab_mask,
          label=vocabulary.encode('1')[-1],
      )
  )
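# Illustrative sketch of how get_vocab_mask and make_inputs might be combined.
# The vocabulary argument, excluded tokens, lengths, and prefix string are
# placeholder assumptions, not values prescribed by this file.
def _example_make_inputs(vocabulary):
  vocab_mask = get_vocab_mask(
      vocabulary, exclude_tokens=['SEP'], exclude_no_space=True)
  return make_inputs(
      prefix='Write a short reply.',
      input_len=8,
      decode_len=16,
      input_for_classify=None,
      vocabulary=vocabulary,
      vocab_mask=vocab_mask,
      dtype=jnp.float32,
  )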
def simple_decode(
    mdl_vars,
    input_onehot,
    use_decoded_mask,
    prefix_len,
    gs_params,
    prng_key,
    greedy,
    wrapped_model,
):
  # pylint: disable-next=g-doc-args
  # pylint: disable-next=g-doc-return-or-yield
  """Calls simple_decode() of the given model.

  We pass the model in a wrapper because the model itself is unhashable and so
  cannot be passed directly into jitted functions.
  """
  full_mdl_vars = mdl_vars.copy()
  full_mdl_vars['params'] = {
      'embedding_model': mdl_vars['params']['lm']['softmax'],
      'model': mdl_vars['params']
  }

  with base_layer.JaxContext.new_context(
      hparams=base_layer.JaxContext.HParams(do_eval=True)):
    outputs, _ = wrapped_model.model.apply(
        full_mdl_vars,
        input_onehot,
        use_decoded_mask,
        prefix_len,
        gs_params,
        greedy,
        method=wrapped_model.model.simple_decode,
        rngs={RANDOM: prng_key},
        mutable=[DECODE_CACHE],
    )
  return outputs
def logits_loss(output_logits, final_label):
  # How closely the final logit matches the label.
  output_probs = nn.softmax(output_logits[:, -1], axis=-1)
  return -output_probs[:, final_label]


def filter_onehot(input_onehot, vocab_mask):
  # This function only works with positive inputs.
  input_onehot = jnp.where(vocab_mask[None, None, :], input_onehot, 0.0)
  return input_onehot / jnp.sum(input_onehot, axis=-1, keepdims=True)
def difference_loss(
    input_logits,
    input_onehot,
    output_logits,
    vocab_mask,
):
  """How much the updated token probabilities differ from the predicted ones."""
  del input_onehot

  # Cross-entropy between the (masked) input distribution and the model's
  # predicted distribution.
  normalized_output_logits = jax.nn.log_softmax(output_logits, axis=-1)
  input_probability = jax.nn.softmax(input_logits, axis=-1)
  input_probability = filter_onehot(input_probability, vocab_mask)

  result_logits = input_probability * normalized_output_logits
  result_logits = jnp.sum(result_logits, -1)

  # Divide by a constant to scale the loss.
  return -result_logits / 30.0
def loss_fn(
    input_logits,
    mdl_vars,
    full_input,
    prng_key,
    input_gs_params,
    decode_gs_params,
    difference_loss_weight,
    batch_size,
    is_eval,
    wrapped_model,
    return_aux_tokens,
):
  """Computes the predicted output label and compares it against the true label.

  The input to simple_decode is constructed from full_input.tokens and
  input_logits.
  If there are tokens in input_for_classify, the input is
  prefix, input_logits, sep, decode. The decode tokens are placeholders.
  simple_decode uses this input to get the output tokens.
  The next input is input_for_classify, sep, decode_output, classify.
  simple_decode uses this input to get the classifier score.
  Note that only input_for_classify is given when calculating the classifier
  score, not input_logits.

  If there are no input_for_classify tokens, the input is
  prefix, input_logits, sep, decode, classify.
  simple_decode gets the decoded model output, then adds the classify tokens
  to get the classifier score.

  Args:
    input_logits: The adversarial logits. The gradient is computed with respect
      to these.
    mdl_vars: Model vars.
    full_input: See make_inputs. Use input_logits instead of the logits stored
      in full_input.
    prng_key: Random key.
    input_gs_params: The temp and hard for the gumbel softmax on the input.
    decode_gs_params: The temp and hard for the gumbel softmax during decode.
    difference_loss_weight: Weight of the difference loss, which encourages the
      adversarial input to have high probability under the model.
    batch_size: Batch size.
    is_eval: True to use the argmax instead of sampling with gumbel softmax.
    wrapped_model: Pax model.
    return_aux_tokens: True to return additional tokens for logging.

  Returns:
    total_loss: Weighted sum of the adversarial and difference losses.
    losses: Map of loss and difference_loss.
    aux_tokens: Returned only if return_aux_tokens.
  """
  input_logits_batch = replicate_batch(input_logits, batch_size)

  if is_eval:
    # Take the softmax because
    # construct_decode_input only works with positive inputs.
    input_onehot = jax.nn.softmax(input_logits_batch, axis=-1)
    assert batch_size == 1
  else:
    prng_key, gs_prng_key = jax.random.split(prng_key)
    input_onehot = gumbel_softmax(input_logits_batch, input_gs_params['temp'],
                                  input_gs_params['hard'], gs_prng_key)

  input_onehot = filter_onehot(input_onehot, full_input.vocab_mask)
  if is_eval:
    input_onehot = calc_max_onehot(input_onehot)

  # Converts the tokens to onehot representation.
  onehot_tokens = {}
  for token_type, tokens in full_input.tokens.items():
    onehot_tokens[token_type] = replicate_batch(
        jax.nn.one_hot(
            tokens, input_onehot.shape[-1], dtype=input_onehot.dtype
        ),
        input_onehot.shape[0],
    )

  # Construct the input.
  decode_prefix_onehot = jnp.concatenate(
      [onehot_tokens['prefix'], input_onehot, onehot_tokens['sep']], axis=1)
  decode_prefix_len = decode_prefix_onehot.shape[1]
  use_decoded_mask = jnp.concatenate([
      jnp.zeros(decode_prefix_onehot.shape[1]),
      jnp.ones(onehot_tokens['decode'].shape[1])
  ], axis=0)
  decode_input_onehot = jnp.concatenate(
      [decode_prefix_onehot, onehot_tokens['decode']], axis=1)

  has_input_for_classify = full_input.tokens.input_for_classify.shape[0] > 0

  if not has_input_for_classify:
    # Add the classify tokens.
    decode_input_onehot = jnp.concatenate(
        [decode_input_onehot, onehot_tokens['classify']], axis=1)
    use_decoded_mask = jnp.concatenate(
        [use_decoded_mask,
         jnp.zeros(onehot_tokens['classify'].shape[1])],
        axis=0)

  prng_key, decode_key = jax.random.split(prng_key)
  decode_out = simple_decode(  # pytype: disable=wrong-arg-types  # jnp-type
      mdl_vars,
      decode_input_onehot,
      use_decoded_mask,
      decode_prefix_len,
      decode_gs_params,
      decode_key,
      greedy=is_eval,
      wrapped_model=wrapped_model,
  )

  # Calc difference_loss.
  update_start = full_input.tokens['prefix'].shape[0] - 1
  # Axis 1 not 0 because input_onehot has been replicated into multiple batches.
  update_end = update_start + input_onehot.shape[1]
  updated_decoded_logits = decode_out['logits'][:, update_start:update_end, :]
  difference_loss_value = jnp.mean(
      difference_loss(
          input_logits_batch,
          input_onehot,
          updated_decoded_logits,
          full_input.vocab_mask,
      )
  )

  if has_input_for_classify:
    # Construct the classifier input.
    classify_input_onehot = jnp.concatenate([
        onehot_tokens['input_for_classify'], onehot_tokens['sep'],
        decode_out['onehot'], onehot_tokens['classify']
    ], axis=1)
    decode_mask_classify = jnp.zeros(classify_input_onehot.shape[1])
    classify_prefix_len = len(decode_mask_classify)

    prng_key, decode_key = jax.random.split(prng_key)
    classify_out = simple_decode(  # pytype: disable=wrong-arg-types  # jnp-array
        mdl_vars,
        classify_input_onehot,
        decode_mask_classify,
        classify_prefix_len,
        decode_gs_params,
        decode_key,
        greedy=is_eval,
        wrapped_model=wrapped_model,
    )

    loss = jnp.mean(logits_loss(classify_out['logits'], full_input.label))

  else:
    loss = jnp.mean(logits_loss(decode_out['logits'], full_input.label))

  total_loss = loss + difference_loss_value * difference_loss_weight
  losses = {'loss': loss, 'difference_loss': difference_loss_value}

  if return_aux_tokens:
    aux_tokens = {
        'decode_prefix': jnp.argmax(decode_prefix_onehot, -1),
        'decode_input': jnp.argmax(decode_input_onehot, -1),
        'decode_out_onehot': jnp.argmax(decode_out['onehot'], -1),
        'decode_out_logits': jnp.argmax(decode_out['logits'], -1),
    }
    if has_input_for_classify:
      aux_tokens['classify_input'] = jnp.argmax(classify_input_onehot, -1)
    return total_loss, losses, aux_tokens
  else:
    return total_loss, losses
loss_fn_jit = jax.jit(
    loss_fn,
    static_argnames=[
        'wrapped_model',
        'batch_size',
        'is_eval',
        'return_aux_tokens',
    ],
)
def loss_grad(
    full_input,
    model_states,
    prng_key,
    input_gs_params,
    decode_gs_params,
    difference_loss_weight,
    batch_size,
    wrapped_model,
):
  """Returns the loss, and gradient of loss_fn."""
  (_, loss), grad = jax.value_and_grad(loss_fn, has_aux=True)(
      full_input.logits,
      model_states.mdl_vars,
      full_input,
      prng_key,
      input_gs_params,
      decode_gs_params,
      difference_loss_weight,
      batch_size,
      is_eval=False,
      wrapped_model=wrapped_model,
      return_aux_tokens=False,
  )
  return loss, grad
@functools.partial(
    jax.pmap,
    in_axes=(0, 0, 0, None, None, None, None, None, None),
    static_broadcasted_argnums=[7, 8],
    axis_name='batch')
# Arguments must be passed by position, not keyword, because of pmap.
def update_input_rep_par(
    full_input,
    model_states,
    prng_key,
    lr,
    input_gs_params,
    decode_gs_params,
    difference_loss_weight,
    local_batch_size,
    wrapped_model,
):
  """Updates the input logits to minimize the loss."""
  prng_key, loss_rng = jax.random.split(prng_key)

  loss, grad = loss_grad(
      full_input,
      model_states,
      loss_rng,
      input_gs_params,
      decode_gs_params,
      difference_loss_weight,
      local_batch_size,
      wrapped_model,
  )

  # Average the gradient and loss across devices.
  grad = jax.lax.pmean(grad, axis_name='batch')
  loss = jax.lax.pmean(loss, axis_name='batch')

  optimizer = optax.adam(lr)
  updates, opt_state = optimizer.update(grad, model_states.opt_states)
  input_logits = optax.apply_updates(full_input.logits, updates)

  return input_logits, opt_state, loss, prng_key
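# Illustrative sketch of an outer optimization loop around update_input_rep_par.
# The number of steps, learning rate, gumbel-softmax settings, and the way the
# per-device arguments are replicated are assumptions for the example; a real
# driver would also fold the returned opt_state back into model_states.
def _example_update_loop(full_input, model_states, wrapped_model):
  num_devices = jax.local_device_count()
  # Arguments with in_axes=0 in the pmap above need a leading device axis.
  full_input = replicate_batch_tree(full_input, num_devices)
  model_states = replicate_batch_tree(model_states, num_devices)
  prng_key = jax.random.split(jax.random.PRNGKey(0), num_devices)
  gs_params = GumbelSoftmaxParams(temp=1.0, hard=True)

  loss = None
  for _ in range(10):
    input_logits, _, loss, prng_key = update_input_rep_par(
        full_input,
        model_states,
        prng_key,
        1e-2,  # lr
        gs_params,  # input_gs_params
        gs_params,  # decode_gs_params
        0.1,  # difference_loss_weight
        1,  # local_batch_size
        wrapped_model,
    )
    full_input.logits = input_logits
  return full_input, loss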
def eval_label_prob(
    full_input,
    model_states,
    verbose,
    vocabulary,
    wrapped_model,
):
  """Computes the probability of the label after the decoding step.

  Uses the token with the highest probability for the input, and during
  decoding. It doesn't use gumbel softmax.

  Args:
    full_input: The result of make_inputs. Used for the logits and vocab_mask.
    model_states: For the model.
    verbose: Prints the inputs and outputs if True.
    vocabulary: Model vocab.
    wrapped_model: The wrapped Pax model.

  Returns:
    The probability of the label after the decoding.
    The difference loss.
    The full tokens that are used to calculate the score. This is the
    adversarial input followed by the separator, followed by the decoded
    output.
  """
  _, losses, aux_tokens = loss_fn_jit(
      full_input.logits,
      model_states.mdl_vars,
      full_input,
      jax.random.PRNGKey(0),
      input_gs_params=None,
      decode_gs_params=None,
      difference_loss_weight=0.0,
      batch_size=1,
      is_eval=True,
      wrapped_model=wrapped_model,
      return_aux_tokens=True,
  )
  loss = losses['loss']

  if verbose:
    display_dict = collections.defaultdict(list)

    for key, tokens in aux_tokens.items():
      display_dict[key] = _display_tokens(tokens[0], vocabulary)

    display_dict = {k: pd.Series(v) for k, v in display_dict.items()}

    print(pd.DataFrame(display_dict).to_string())

  decode_out_tokens = aux_tokens['decode_out_onehot'][
      0, :len(full_input.tokens['decode']) + 1]
  decode_input_output = jnp.concatenate(
      [aux_tokens['decode_prefix'][0], decode_out_tokens])

  return -loss, losses['difference_loss'], decode_input_output
def make_model_input(tokens):
  num_tokens = len(tokens)
  return NestedMap.FromNestedDict({
      'ids': tokens,
      'labels': np.zeros([num_tokens], np.int32),
      'paddings': np.zeros([num_tokens], np.int32),
      'weights': np.zeros([num_tokens], np.int32),
      'segment_ids': None,
      'segment_pos': None,
  })
@functools.partial(jax.jit, static_argnames='wrapped_model')
def regular_decode(model_states, input_tokens, wrapped_model):
  """Decodes using the built in PAX decoding."""
  model_input = replicate_batch_tree(make_model_input(input_tokens), 1)
  var_weight_hparams = wrapped_model.model.abstract_init_with_metadata(
      model_input
  )
  (_, per_example_out, _), _ = trainer_lib.decode_step(
      wrapped_model.model,
      model_states.to_eval_state(),
      jax.random.PRNGKey(1234),
      var_weight_hparams,
      model_input,
      fprop_dtype=wrapped_model.model.fprop_dtype,
  )
  return per_example_out
def dec_enc(tokens,
            vocabulary):
  """Decodes the tokens with the vocab then encodes them again.

  This usually gives the same result. It is used to make sure the input
  consists of tokens that the tokenizer can actually produce.

  Args:
    tokens: The tokens to round-trip.
    vocabulary: The vocabulary.

  Returns:
    The tokens after decoding then encoding.
  """
  return jnp.array([0] + list(vocabulary.encode(vocabulary.decode(tokens))))
def filter_after_eos(tokens):
  """Removes the eos token (value 1) and all tokens after it."""
  # The 0th element of where is the first occurrence.
  index = jnp.where(tokens == 1)[0]
  if index.shape[0] == 0:
    return tokens
  return tokens[:index[0]]
def eval_label_prob_reg_decode(
    full_input,
    model_states,
    verbose,
    vocabulary,
    wrapped_model,
    use_dec_enc,
):
  """Similar to eval_label_prob, but uses the normal PAX decode algorithm.

  Uses tokens instead of a one-hot encoding.
  Usually it gives the same results as eval_label_prob, but sometimes the
  results differ due to floating point errors.

  Args:
    full_input: The result of make_inputs. Used for the logits and vocab_mask.
    model_states: For the model.
    verbose: Prints the inputs and outputs if True.
    vocabulary: Model vocab.
    wrapped_model: The wrapped Pax model.
    use_dec_enc: True to apply dec_enc to the input.

  Returns:
    The probability of the label after the decoding.
    The full tokens that are used to calculate the score. This is the
    adversarial input followed by the separator, followed by the decoded
    output.
  """
  input_onehot = jax.nn.softmax(full_input.logits, axis=-1)
  input_onehot = filter_onehot(input_onehot, full_input.vocab_mask)[0]
  input_tokens = jnp.argmax(input_onehot, -1)

  if use_dec_enc:
    input_tokens = dec_enc(input_tokens, vocabulary)
  full_input_tokens = jnp.concatenate(
      [full_input.tokens.prefix, input_tokens, full_input.tokens.sep], 0)
  if use_dec_enc:
    full_input_tokens = dec_enc(full_input_tokens, vocabulary)

  decode_out = regular_decode(model_states, full_input_tokens, wrapped_model)

  decode_end_i = full_input_tokens.shape[0] + full_input.tokens.decode.shape[0]
  if full_input.tokens.input_for_classify.shape[0] > 0:
    decode_end_i += 1
  decoded_tokens = decode_out['output_ids'][
      0, 0, full_input_tokens.shape[0]:decode_end_i]
  decoded_tokens = filter_after_eos(decoded_tokens)

  full_decode_tokens = jnp.concatenate([full_input_tokens, decoded_tokens], 0)

  if full_input.tokens.input_for_classify.shape[0] > 0:
    safety_classifier_input = jnp.concatenate([
        full_input.tokens.input_for_classify, full_input.tokens.sep,
        decoded_tokens, full_input.tokens.classify
    ], 0)
  else:
    safety_classifier_input = jnp.concatenate(
        [full_decode_tokens, full_input.tokens.classify], 0)

  safety_out = regular_decode(
      model_states, safety_classifier_input, wrapped_model
  )

  safety_token = safety_out['output_ids'][0, 0,
                                          safety_classifier_input.shape[0]]
  safety_token_prob = jnp.exp(
      safety_out['logprobs'][0, 0, safety_classifier_input.shape[0]])

  if verbose:
    display_dict = collections.defaultdict(list)

    display_dict['decoded_tokens'] = _display_tokens(full_decode_tokens,
                                                     vocabulary)
    display_dict['safety_out'] = _display_tokens(
        safety_out['output_ids'][0, 0, :safety_classifier_input.shape[0] + 1],
        vocabulary)

    display_dict = {k: pd.Series(v) for k, v in display_dict.items()}

    print(pd.DataFrame(display_dict).to_string())

  if safety_token == full_input.label:
    return safety_token_prob, full_decode_tokens
  else:
    return 1 - safety_token_prob, full_decode_tokens
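# Illustrative sketch of evaluating an optimized adversarial prefix with both
# evaluation paths. The full_input, model_states, vocabulary, and wrapped_model
# arguments are placeholders that a driver script would construct.
def _example_evaluate(full_input, model_states, vocabulary, wrapped_model):
  # Differentiable path: one-hot inputs through simple_decode.
  label_prob, _, _ = eval_label_prob(
      full_input, model_states, verbose=False,
      vocabulary=vocabulary, wrapped_model=wrapped_model)
  # Regular PAX decoding on discrete tokens; usually agrees with the above.
  reg_label_prob, _ = eval_label_prob_reg_decode(
      full_input, model_states, verbose=False, vocabulary=vocabulary,
      wrapped_model=wrapped_model, use_dec_enc=True)
  return label_prob, reg_label_prob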
def _display_tokens(
    tokens,
    vocabulary):
  display_vals = []
  for token in tokens:
    display_vals.append(vocabulary.decode([token]))
  return display_vals