# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The actual Contrack model."""

import logging
import math
from typing import Any, Callable, Dict, Iterator, List, Tuple

import tensorflow as tf

from contrack import custom_ops
from contrack import encoding
from contrack.env import Env


def _pad_and_clip_seq_batch(seq_batch, seq_len_batch, pad_value, maxlen,
                            data_vec_len):
  """Pads or clips a batch of sequences to a fixed length with a pad value."""
  with tf.name_scope('pad_seq_batch'):
    seq_mask = tf.sequence_mask(lengths=seq_len_batch, maxlen=maxlen)
    seq_mask = tf.expand_dims(seq_mask, 2)
    seq_mask = tf.tile(seq_mask, [1, 1, data_vec_len])
    # Trim or pad seq_batch as needed to make the shapes compatible
    padded_shape = [tf.shape(seq_batch)[0], maxlen, data_vec_len]
    seq_batch = seq_batch[:, :maxlen, :]
    seq_dim_pad_len = tf.constant(maxlen) - tf.shape(seq_batch)[1]
    seq_batch = tf.pad(
        seq_batch, paddings=[[0, 0], [0, seq_dim_pad_len], [0, 0]])
    seq_batch.set_shape([seq_batch.shape[0], maxlen, data_vec_len])
    pad_value = tf.cast(pad_value, dtype=seq_batch.dtype)
    pad_batch = tf.fill(padded_shape, value=pad_value)
    padded_seq_batch = tf.where(seq_mask, seq_batch, pad_batch)
    return padded_seq_batch
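
# A minimal usage sketch (values illustrative only):
#   seq = tf.random.uniform([2, 5, 3])
#   padded = _pad_and_clip_seq_batch(
#       seq, seq_len_batch=tf.constant([4, 2]), pad_value=0.0, maxlen=6,
#       data_vec_len=3)
#   # padded.shape == [2, 6, 3]; positions at or beyond each sequence's
#   # length hold pad_value.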


def shape_list(x):
  """Return list of dims, statically where possible."""
  x = tf.convert_to_tensor(x)

  # If unknown rank, return dynamic shape
  if x.get_shape().dims is None:
    return tf.shape(x)

  static = x.get_shape().as_list()
  shape = tf.shape(x)

  ret = []
  for i, dim in enumerate(static):
    if dim is None:
      dim = shape[i]
    ret.append(dim)
  return ret
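
# Example sketch: mixing static and dynamic dimensions.
#   @tf.function(input_signature=[tf.TensorSpec([None, None, 8])])
#   def f(x):
#     dims = shape_list(x)
#     # dims[0] and dims[1] are scalar tensors; dims[2] is the Python int 8.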


def split_heads(x, n):
  """Splits the last dim into n heads: [batch, len, d] -> [batch, n, len, d/n]."""
  x_shape = shape_list(x)
  m = x_shape[-1]
  if isinstance(m, int) and isinstance(n, int):
    assert m % n == 0
  y = tf.reshape(x, x_shape[:-1] + [n, m // n])

  return tf.transpose(y, [0, 2, 1, 3])


def combine_heads(x):
  """Inverse of split_heads: [batch, n, len, d/n] -> [batch, len, d]."""
  x = tf.transpose(x, [0, 2, 1, 3])
  x_shape = shape_list(x)
  a, b = x_shape[-2:]
  return tf.reshape(x, x_shape[:-2] + [a * b])
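
# Round-trip sketch: combine_heads(split_heads(x, n)) recovers x.
#   x = tf.reshape(tf.range(48, dtype=tf.float32), [2, 3, 8])
#   y = split_heads(x, 4)   # [2, 4, 3, 2]
#   z = combine_heads(y)    # [2, 3, 8], equal to x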


class ConvertToSequenceLayer(tf.keras.layers.Layer):
  """Concatenates input data into a sequence suitable for prediction."""

  def __init__(self, input_vec_len):
    super(ConvertToSequenceLayer, self).__init__()
    self.config = Env.get().config
    self.input_vec_len = input_vec_len

  @classmethod
  def from_config(cls, config):
    return ConvertToSequenceLayer(config['input_vec_len'])

  def get_config(self):
    return {'input_vec_len': self.input_vec_len}

  def compute_mask(self, inputs, mask=None):
    state_seq_len = inputs['state_seq_length']
    token_seq_len = inputs['token_seq_length']

    input_seq_len = tf.add(state_seq_len, token_seq_len)

    return tf.sequence_mask(input_seq_len, maxlen=self.config.max_seq_len)

  def call(self, inputs, training=None):
    with tf.name_scope('convert_to_sequence'):
      state_seq_len = tf.cast(inputs['state_seq_length'], tf.int32)
      state_seq = inputs['state_seq']

      token_seq_len = tf.cast(inputs['token_seq_length'], tf.int32)
      token_seq = inputs['token_seq']

      input_seq, input_seq_len = custom_ops.sequence_concat(
          sequences=[state_seq, token_seq],
          lengths=[state_seq_len, token_seq_len])

      # Clip and pad seq
      input_seq_len = tf.minimum(
          input_seq_len, self.config.max_seq_len, name='input_seq_len')
      input_seq = _pad_and_clip_seq_batch(
          input_seq,
          input_seq_len,
          pad_value=0,
          maxlen=self.config.max_seq_len,
          data_vec_len=self.input_vec_len)

      # Add timing signal
      if self.config.timing_signal_size > 0:
        num_channels = self.config.timing_signal_size
        positions = tf.cast(tf.range(self.config.max_seq_len), dtype=tf.int64)

        min_timescale = 1.0
        max_timescale = 1.0e4

        with tf.name_scope('TimingSignal'):
          num_timescales = num_channels // 2
          log_timescale_increment = (
              math.log(max_timescale / min_timescale) /
              (tf.cast(num_timescales, tf.float32) - 1))
          inv_timescales = min_timescale * tf.exp(
              tf.cast(tf.range(num_timescales), tf.float32) *
              -log_timescale_increment)
          scaled_time = (
              tf.expand_dims(tf.cast(positions, tf.float32), 1) *
              tf.expand_dims(inv_timescales, 0))
          signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
          time_signal = tf.pad(
              signal, [[0, 0], [0, tf.math.floormod(num_channels, 2)]])

          time_signal = tf.expand_dims(time_signal, 0)
          time_signal = tf.tile(time_signal, [tf.shape(input_seq)[0], 1, 1])
          seq_mask = tf.cast(
              tf.sequence_mask(
                  lengths=input_seq_len, maxlen=self.config.max_seq_len),
              dtype=tf.float32)
          seq_mask = tf.expand_dims(seq_mask, 2)
          seq_mask = tf.tile(seq_mask, [1, 1, num_channels])
          time_signal = time_signal * seq_mask
          input_seq = tf.concat([input_seq, time_signal],
                                axis=2,
                                name='add_time_signal')

      return input_seq, input_seq_len
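
  # The block above is the standard transformer sinusoidal timing signal:
  # for position p and timescale index i, channel i holds
  # sin(p / max_timescale**(i / (num_timescales - 1))) and channel
  # num_timescales + i holds the matching cosine. Unlike the original
  # transformer, the signal is concatenated to the features rather than
  # added. A self-contained sketch with num_channels = 4 (illustrative):
  #   num_timescales = 2
  #   inv_timescales = tf.exp(
  #       tf.range(2, dtype=tf.float32) * -(math.log(1.0e4) / (2 - 1)))
  #   scaled = tf.expand_dims(tf.range(3, dtype=tf.float32), 1) * inv_timescales
  #   signal = tf.concat([tf.sin(scaled), tf.cos(scaled)], axis=1)  # [3, 4]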


class IdentifyNewEntityLayer(tf.keras.layers.Layer):
  """The layer identifying new entities in a message."""

  def __init__(self, seq_shape):
    super(IdentifyNewEntityLayer, self).__init__()
    self.config = Env.get().config.new_id_attention
    self.supports_masking = True
    self.seq_shape = seq_shape

    self.batch_size = seq_shape[0]
    self.seq_len = seq_shape[1]
    key_dim = math.ceil(seq_shape[2] / self.config.num_heads)
    output_shape = [
        self.batch_size, self.seq_len, key_dim * self.config.num_heads
    ]
    self.attention = tf.keras.layers.MultiHeadAttention(
        num_heads=self.config.num_heads,
        key_dim=key_dim,
        value_dim=key_dim,
        use_bias=True,
        dropout=self.config.dropout_rate,
        output_shape=[output_shape[2]])

    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=2, epsilon=1e-6, name='SelfAttentionNorm')

    self.affine = tf.keras.layers.Dense(2, use_bias=True)

    self.q_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)
    self.k_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)
    self.v_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)

    hidden_size = 100
    filter_size = 800
    self.attention_dense = tf.keras.layers.Dense(hidden_size, use_bias=True)
    self.match_residual_dense = tf.keras.layers.Dense(
        hidden_size, use_bias=True)
    self.ffn_layer_norm = tf.keras.layers.LayerNormalization(
        axis=2, epsilon=1e-6, name='FFNNorm')

    self.ff_relu_dense = tf.keras.layers.Dense(
        filter_size, use_bias=True, activation='relu')
    self.ff_dense = tf.keras.layers.Dense(hidden_size, use_bias=True)

  @classmethod
  def from_config(cls, config):
    return IdentifyNewEntityLayer(config['seq_shape'])

  def get_config(self):
    return {'seq_shape': self.seq_shape}

  def compute_mask(self, inputs, mask=None):
    return mask

  def call(self, inputs, training=None, mask=None):
    x = inputs

    # Make feature space size a multiple of num_heads
    num_heads = self.config.num_heads
    if x.shape[-1] % num_heads > 0:
      with tf.name_scope('PadForMultipleOfHeads'):
        fill_size = num_heads - x.shape[-1] % num_heads
        fill_mat = tf.tile(tf.zeros_like(x[:, :, :1]), [1, 1, fill_size])
        x = tf.concat([x, fill_mat], 2)

    # Multihead Attention
    input_depth = x.shape[-1]
    x = self.layer_norm(x, training=training)
    q = split_heads(self.q_dense(x, training=training), num_heads)
    k = split_heads(self.k_dense(x, training=training), num_heads)
    v = split_heads(self.v_dense(x, training=training), num_heads)

    key_depth_per_head = input_depth // num_heads
    q *= key_depth_per_head**-0.5

    logits = tf.matmul(q, k, transpose_b=True)
    weights = tf.nn.softmax(logits, name='attention_weights')
    y = tf.matmul(weights, v)

    y = combine_heads(y)
    y = self.attention_dense(y)

    r = self.match_residual_dense(x)
    y = self.ffn_layer_norm(y + r, training=training)

    # Feed forward
    z = self.ff_relu_dense(y, training=training)
    z = self.ff_dense(z, training=training)
    z += y

    # Affine layer
    z = tf.concat([inputs[:, :, :68], z], 2)
    logits = self.affine(z, training=training)

    # Apply mask
    logits *= tf.expand_dims(tf.cast(mask, tf.float32), -1)

    return logits
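
  # The call() above implements standard scaled dot-product attention per
  # head h: y_h = softmax(q_h @ k_h^T / sqrt(d_head)) @ v_h. A toy
  # single-head sketch (values illustrative only):
  #   q = k = v = tf.constant([[[1.0, 0.0], [0.0, 1.0]]])  # [1, 2, 2]
  #   w = tf.nn.softmax(tf.matmul(q, k, transpose_b=True) * 2.0**-0.5)
  #   y = tf.matmul(w, v)  # [1, 2, 2]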


class ComputeIdsLayer(tf.keras.layers.Layer):
  """Compute Ids for the (new entity) tokens in the input sequence."""

  def __init__(self):
    super(ComputeIdsLayer, self).__init__()
    self.encodings = Env.get().encodings
    self.config = Env.get().config

  @classmethod
  def from_config(cls, config):
    del config
    return ComputeIdsLayer()

  def get_config(self):
    return {}

  def compute_mask(self, inputs, mask=None):
    _, seq_len, _ = inputs
    return tf.sequence_mask(seq_len, maxlen=self.config.max_seq_len)

  def call(self, inputs, mask=None):
    seq, enref_seq_len, is_new_logits = inputs

    enref_seq_len = tf.cast(enref_seq_len, dtype=tf.int32)

    enref_ids = self.encodings.as_enref_encoding(seq).enref_id.slice()

    is_new_entity = tf.cast(is_new_logits[:, :, 0] > 0.0, tf.float32)

    new_id_one_hot = custom_ops.new_id(
        state_ids=enref_ids, state_len=enref_seq_len, is_new=is_new_entity)
    new_id_one_hot = tf.stop_gradient(new_id_one_hot)

    return new_id_one_hot
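
  # Note on the threshold above: training uses hinge losses (see
  # ContrackLoss), so the sign of a logit carries the decision;
  # is_new_logits[:, :, 0] > 0.0 marks tokens that start a new entity.
  # custom_ops.new_id is a custom op defined elsewhere in the package; as
  # used here, it is assumed to allocate a one-hot id for each such token,
  # taking the ids already present in the state into account.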


class TrackEnrefsLayer(tf.keras.layers.Layer):
  """Predicts enref ids, properties, and group membership."""

  def __init__(self, seq_shape):
    super(TrackEnrefsLayer, self).__init__()

    self.config = Env.get().config.tracking_attention
    self.encodings = Env.get().encodings
    self.supports_masking = True
    self.seq_shape = seq_shape

    self.batch_size = seq_shape[0]
    self.seq_len = seq_shape[1]
    attention_input_length = (
        seq_shape[2] + 2 + self.encodings.new_enref_encoding().enref_id.SIZE)
    key_dim = math.ceil(attention_input_length / self.config.num_heads)
    output_shape = [
        self.batch_size, self.seq_len, key_dim * self.config.num_heads
    ]
    self.attention = tf.keras.layers.MultiHeadAttention(
        num_heads=self.config.num_heads,
        key_dim=key_dim,
        value_dim=key_dim,
        use_bias=True,
        dropout=self.config.dropout_rate,
        output_shape=[output_shape[2]])

    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=2, epsilon=1e-6, name='SelfAttentionNorm')

    self.affine = tf.keras.layers.Dense(
        self.encodings.prediction_encoding_length, use_bias=True)

    self.q_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)
    self.k_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)
    self.v_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)

    hidden_size = 100
    filter_size = 800
    self.attention_dense = tf.keras.layers.Dense(hidden_size, use_bias=True)
    self.match_residual_dense = tf.keras.layers.Dense(
        hidden_size, use_bias=True)
    self.ffn_layer_norm = tf.keras.layers.LayerNormalization(
        axis=2, epsilon=1e-6, name='FFNNorm')

    self.ff_relu_dense = tf.keras.layers.Dense(
        filter_size, use_bias=True, activation='relu')
    self.ff_dense = tf.keras.layers.Dense(hidden_size, use_bias=True)

  @classmethod
  def from_config(cls, config):
    return TrackEnrefsLayer(config['seq_shape'])

  def get_config(self):
    return {'seq_shape': self.seq_shape}

  def compute_mask(self, inputs, mask=None):
    return mask[0]

  def call(self, inputs, training=None, mask=None):
    seq, is_new_entity, new_ids = inputs

    is_new_entity = tf.stop_gradient(is_new_entity)

    x = tf.concat([seq, is_new_entity, new_ids], axis=2)

    # Make feature space size a multiple of num_heads
    num_heads = self.config.num_heads
    if x.shape[-1] % num_heads > 0:
      with tf.name_scope('PadForMultipleOfHeads'):
        fill_size = num_heads - x.shape[-1] % num_heads
        fill_mat = tf.tile(tf.zeros_like(x[:, :, :1]), [1, 1, fill_size])
        x = tf.concat([x, fill_mat], 2)

    # Multihead Attention
    input_depth = x.shape[-1]
    x = self.layer_norm(x, training=training)
    q = split_heads(self.q_dense(x, training=training), num_heads)
    k = split_heads(self.k_dense(x, training=training), num_heads)
    v = split_heads(self.v_dense(x, training=training), num_heads)

    key_depth_per_head = input_depth // num_heads
    q *= key_depth_per_head**-0.5

    logits = tf.matmul(q, k, transpose_b=True)
    weights = tf.nn.softmax(logits, name='attention_weights')
    y = tf.matmul(weights, v)

    y = combine_heads(y)
    y = self.attention_dense(y)

    r = self.match_residual_dense(x)
    y = self.ffn_layer_norm(y + r, training=training)

    # Feed forward
    z = self.ff_relu_dense(y, training=training)
    z = self.ff_dense(z, training=training)
    z += y

    # Affine layer
    logits = self.affine(z, training=training)

    # Apply mask
    logits *= tf.expand_dims(tf.cast(mask[0], tf.float32), -1)

    return logits
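
  # Shape note: attention_input_length in __init__ mirrors the concat in
  # call(): the input features (seq_shape[2]) plus the two new-entity
  # logits plus one channel per enref id (enref_id.SIZE).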


class MergeIdsLayer(tf.keras.layers.Layer):
  """Layer merging the ids from new entities and existing entities."""

  def __init__(self):
    super(MergeIdsLayer, self).__init__()

    self.encodings = Env.get().encodings

  @classmethod
  def from_config(cls, config):
    del config
    return MergeIdsLayer()

  def get_config(self):
    return {}

  def call(self, inputs, training=None, mask=None):
    is_new_entity, new_ids, logits = inputs

    logits_encoding = self.encodings.as_prediction_encoding(logits)
    existing_ids = logits_encoding.enref_id.slice()

    is_new_id = tf.cast(is_new_entity > 0.0, tf.float32)
    is_new_id = tf.reduce_max(is_new_id, 2, keepdims=True)
    ids = is_new_id * new_ids
    ids += (1.0 - is_new_id) * existing_ids

    logits = logits_encoding.enref_id.replace(ids)
    logits = self.encodings.as_prediction_encoding(
        logits).enref_meta.replace_is_new_slice(is_new_entity)

    return logits
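
  # The merge is a per-token selection between the two id sources:
  #   ids = is_new * new_ids + (1 - is_new) * existing_ids
  # For example (batch dimension omitted), with is_new = [[1.], [0.]],
  # new_ids = [[0., 1.], [0., 1.]] and existing_ids = [[.9, .1], [.9, .1]],
  # ids = [[0., 1.], [.9, .1]]: the first token takes the freshly allocated
  # id, the second keeps the id predicted by the tracking layer.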


class ContrackModel(tf.keras.Model):
  """The Contrack model."""

  def __init__(self, mode, print_predictions=False):
    super(ContrackModel, self).__init__()

    self.config = Env.get().config
    self.encodings = Env.get().encodings
    self.mode = mode
    self.print_predictions = print_predictions
    self.teacher_forcing = True

    self.convert_to_sequence_layer = ConvertToSequenceLayer(
        self.encodings.enref_encoding_length)

    input_shape = [
        self.config.batch_size, self.config.max_seq_len,
        self.encodings.enref_encoding_length + self.config.timing_signal_size
    ]
    self.new_entity_layer = IdentifyNewEntityLayer(input_shape)

    self.compute_ids_layer = ComputeIdsLayer()

    self.track_enrefs_layer = TrackEnrefsLayer(input_shape)

    self.merge_ids_layer = MergeIdsLayer()

  @classmethod
  def from_config(cls, config):
    return ContrackModel(config['mode'])

  def get_config(self):
    return {'mode': self.mode}

  def init_weights_from_new_entity_model(self, model):
    # Call the model once to create weights
    input_vec_shape = [
        self.config.batch_size, self.config.max_seq_len,
        self.encodings.token_encoding_length
    ]
    null_input = {
        'state_seq_length': tf.ones([self.config.batch_size]),
        'state_seq': tf.zeros(input_vec_shape),
        'token_seq_length': tf.ones([self.config.batch_size]),
        'token_seq': tf.zeros(input_vec_shape)
    }
    self(null_input)

    # Then copy over layer weights
    self.new_entity_layer.set_weights(model.new_entity_layer.get_weights())

  def call(self, inputs, training=False):
    # Step 1: Concatenate input data into a single sequence
    seq, _ = self.convert_to_sequence_layer(inputs)

    # Step 2: Identify new entities.
    is_new_entity = self.new_entity_layer(seq)

    if self.mode == 'only_new_entities':
      res = tf.zeros_like(seq[:, :, :self.encodings.prediction_encoding_length])
      res_enc = self.encodings.as_prediction_encoding(res)
      res = res_enc.enref_meta.replace_is_new_slice(is_new_entity)
      return res
    elif self.mode == 'only_tracking':
      is_new_entity = tf.stop_gradient(is_new_entity)

    # Step 3: Compute enref ids for new entities
    new_ids = self.compute_ids_layer(
        (seq, inputs['state_seq_length'], is_new_entity))

    # Step 4: Determine enref predictions
    logits = self.track_enrefs_layer((seq, is_new_entity, new_ids))

    # Step 5: Merge ids from new and existing enrefs
    logits = self.merge_ids_layer((is_new_entity, new_ids, logits))

    return logits
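
  # A minimal training-setup sketch (assumes the Env singleton is already
  # initialized; any mode other than 'only_new_entities' and 'only_tracking'
  # runs the full pipeline, e.g. a string like 'full'):
  #   model = ContrackModel(mode='full')
  #   model.compile(
  #       optimizer=tf.keras.optimizers.Adam(),
  #       loss=ContrackLoss(mode='full'),
  #       metrics=build_metrics(mode='full'))
  #   model.fit(train_dataset)  # train_dataset is a hypothetical tf.data set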

  def train_step(self, data):
    """The training step."""
    x = data

    # Shift true labels seq to align with tokens in input_seq
    state_seq_len = tf.cast(data['state_seq_length'], tf.int32)
    token_seq_len = tf.cast(data['token_seq_length'], tf.int32)
    state_seq_dims = tf.shape(data['state_seq'])
    enref_padding = tf.zeros([
        state_seq_dims[0], state_seq_dims[1],
        self.encodings.prediction_encoding_length
    ])
    y, y_len = custom_ops.sequence_concat(
        sequences=[enref_padding, data['annotation_seq']],
        lengths=[state_seq_len, token_seq_len])

    # Clip and pad true labels seq
    y_len = tf.minimum(y_len, self.config.max_seq_len, name='y_len')
    y = _pad_and_clip_seq_batch(
        y,
        y_len,
        pad_value=0,
        maxlen=self.config.max_seq_len,
        data_vec_len=self.encodings.prediction_encoding_length)

    input_seq_len = tf.add(state_seq_len, token_seq_len)
    seq_mask = tf.sequence_mask(
        input_seq_len, maxlen=self.config.max_seq_len, dtype=tf.float32)
    enref_mask = tf.sequence_mask(
        state_seq_len, maxlen=self.config.max_seq_len, dtype=tf.float32)
    sample_weight = seq_mask - enref_mask

    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)
      loss = self.compiled_loss(
          y, y_pred, sample_weight, regularization_losses=self.losses)
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {m.name: m.result() for m in self.metrics}
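
  # sample_weight sketch: with max_seq_len = 6, state_seq_len = 2 and
  # token_seq_len = 3, seq_mask is [1, 1, 1, 1, 1, 0] and enref_mask is
  # [1, 1, 0, 0, 0, 0], so sample_weight = [0, 0, 1, 1, 1, 0]: loss and
  # metrics only see token positions, not the enref-state prefix or padding.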

  def predict_step(self, data):
    """The logic for one inference step."""
    if not self.teacher_forcing:
      data, y_pred = self.call_without_teacher_forcing(data)
    else:
      y_pred = self(data, training=False)

    x = {
        'state_seq_length': data['state_seq_length'],
        'token_seq_length': data['token_seq_length'],
        'scenario_id': data['scenario_id']
    }
    x['state_seq'] = _pad_and_clip_seq_batch(
        data['state_seq'],
        data['state_seq_length'],
        pad_value=0,
        maxlen=self.config.max_seq_len,
        data_vec_len=self.encodings.enref_encoding_length)
    x['token_seq'] = _pad_and_clip_seq_batch(
        data['token_seq'],
        data['token_seq_length'],
        pad_value=0,
        maxlen=self.config.max_seq_len,
        data_vec_len=self.encodings.token_encoding_length)
    x['word_seq'] = _pad_and_clip_seq_batch(
        tf.expand_dims(data['word_seq'], -1),
        data['token_seq_length'],
        pad_value='',
        maxlen=self.config.max_seq_len,
        data_vec_len=1)
    x['annotation_seq'] = _pad_and_clip_seq_batch(
        data['annotation_seq'],
        data['token_seq_length'],
        pad_value=0,
        maxlen=self.config.max_seq_len,
        data_vec_len=self.encodings.prediction_encoding_length)

    return (x, y_pred)

  def make_test_function(self):
    """Creates a function that executes one step of evaluation."""
    test_fn = super(ContrackModel, self).make_test_function()

    def adapted_test_fn(iterator):
      outputs = test_fn(iterator)
      if 'print_prediction' in outputs:
        pred_msgs = outputs['print_prediction']
        if self.print_predictions:
          logging.info(pred_msgs.numpy().decode('utf-8'))
        del outputs['print_prediction']
      return outputs

    self.test_function = adapted_test_fn
    return adapted_test_fn

  def print_prediction(self, seq_len, state_seq_len, words, tokens,
                       predictions, true_targets):
    """Renders predictions next to true targets, one line per token."""
    res = ''

    for batch_index, num_token in enumerate(seq_len.numpy()):
      res += '---------------------------------------\n'
      for i in range(num_token):
        word = words[batch_index, i].numpy().decode('utf-8')
        res += word + ': '

        seq_index = state_seq_len[batch_index] + i
        if seq_index >= self.config.max_seq_len:
          break

        true_target = self.encodings.as_prediction_encoding(
            true_targets[batch_index, seq_index, :].numpy())

        pred = self.encodings.as_prediction_encoding(
            predictions[batch_index, seq_index, :].numpy())

        if self.mode == 'only_new_entities':
          true_label = '%s%s' % (
              'n' if true_target.enref_meta.is_new() > 0.0 else '',
              'c' if true_target.enref_meta.is_new_continued() > 0.0 else '')
          predicted_label = '%s%s' % (
              'n' if pred.enref_meta.is_new() > 0.0 else '',
              'c' if pred.enref_meta.is_new_continued() > 0.0 else '')
          if true_label != predicted_label:
            res += '*** %s != %s' % (predicted_label, true_label)
            res += ' ' + str(pred.enref_meta.slice())
          else:
            res += predicted_label
        else:
          token = self.encodings.as_token_encoding(tokens[batch_index, i, :])
          true_enref = self.encodings.build_enref_from_prediction(
              token, true_target)
          pred_enref = self.encodings.build_enref_from_prediction(token, pred)
          if str(true_enref) != str(pred_enref):
            res += '*** %s != %s' % (str(pred_enref), str(true_enref))
            res += ' %s' % str([
                round(a, 2)
                for a in true_targets[batch_index, seq_index, :].numpy()
            ])
          else:
            res += str(pred_enref) if pred_enref is not None else ''

        res += '\n'

    return res

  def test_step(self, data):
    """The logic for one evaluation step."""
    x = data

    # Shift true labels seq to align with tokens in input_seq
    state_seq_len = tf.cast(data['state_seq_length'], tf.int32)
    token_seq_len = tf.cast(data['token_seq_length'], tf.int32)
    state_seq_dims = tf.shape(data['state_seq'])
    enref_padding = tf.zeros([
        state_seq_dims[0], state_seq_dims[1],
        self.encodings.prediction_encoding_length
    ])
    y, y_len = custom_ops.sequence_concat(
        sequences=[enref_padding, data['annotation_seq']],
        lengths=[state_seq_len, token_seq_len])

    # Clip and pad true labels seq
    y_len = tf.minimum(y_len, self.config.max_seq_len, name='y_len')
    y = _pad_and_clip_seq_batch(
        y,
        y_len,
        pad_value=0,
        maxlen=self.config.max_seq_len,
        data_vec_len=self.encodings.prediction_encoding_length)

    input_seq_len = tf.add(state_seq_len, token_seq_len)
    seq_mask = tf.sequence_mask(
        input_seq_len, maxlen=self.config.max_seq_len, dtype=tf.float32)
    enref_mask = tf.sequence_mask(
        state_seq_len, maxlen=self.config.max_seq_len, dtype=tf.float32)
    sample_weight = seq_mask - enref_mask

    y_pred = self(x, training=False)
    # Updates stateful loss metrics.
    self.compiled_loss(
        y, y_pred, sample_weight, regularization_losses=self.losses)

    self.compiled_metrics.update_state(y, y_pred, sample_weight)

    # Print prediction to log
    output_tensors = {m.name: m.result() for m in self.metrics}

    print_prediction_fn = tf.py_function(self.print_prediction, [
        x['token_seq_length'], x['state_seq_length'], x['word_seq'],
        x['token_seq'], y_pred, y
    ], tf.string)

    output_tensors['print_prediction'] = print_prediction_fn
    return output_tensors

  def disable_teacher_forcing(self):
    self.teacher_forcing = False
    self.current_scenario = None
    self.current_enrefs = []
    self.current_participants = []
    assert self.config.batch_size == 1

  def call_without_teacher_forcing(self, data):
    scenario_id = data['scenario_id'][0].numpy().decode('utf-8')
    # logging.info(scenario_id)
    if scenario_id == self.current_scenario:
      # Continue existing conversation, create state_seq from enrefs
      enrefs = self.current_enrefs
      logging.info('Continue conversation with %d enrefs', len(enrefs))

      data['state_seq_length'] = tf.constant([len(enrefs)], dtype=tf.int64)
      sender = data['sender'][0].numpy().decode('utf-8')
      for enref in enrefs:
        entity_name = enref.entity_name
        enref.enref_context.set_is_sender(entity_name == sender)
        enref.enref_context.set_is_recipient(
            entity_name != sender and entity_name in self.current_participants)
        enref.enref_context.set_message_offset(
            enref.enref_context.get_message_offset() + 1)

      # for i, e in enumerate(enrefs):
      #   diff = e.array - data['state_seq'][0, i, :].numpy()
      #   if np.sum(np.abs(diff)) > 0.1:
      #     logging.info('diff for %s: %s', str(e), diff.tolist())

      enref_seq = [e.array for e in enrefs]
      data['state_seq'] = tf.constant([enref_seq], dtype=tf.float32)
    else:
      # Start new conversation, obtain initial enrefs from participants
      logging.info('New conversation, participants %s',
                   data['participants'].values)
      self.current_scenario = scenario_id
      self.current_participants = [
          p.numpy().decode('utf-8') for p in data['participants'].values
      ]
      self.current_enrefs = []
      for i in range(0, data['state_seq_length'][0].numpy()):
        enref_array = data['state_seq'][0, i].numpy()
        enref = self.encodings.as_enref_encoding(enref_array)
        enref_name = self.current_participants[i]
        enref.populate(enref_name, (i, i + 1), enref_name)
        self.current_enrefs.append(enref)
      # logging.info('Enrefs: %s', str(self.current_enrefs))

    # Run model
    y_pred = self(data, training=False)

    # Update set of enrefs from prediction
    num_tokens = len(data['word_seq'][0])
    num_enrefs = len(data['state_seq'][0])

    token_encs = [self.encodings.as_token_encoding(
        data['token_seq'][0, i, :].numpy()) for i in range(0, num_tokens)]
    pred_encs = [self.encodings.as_prediction_encoding(y_pred[0, i, :].numpy())
                 for i in range(num_enrefs, min(num_enrefs + num_tokens,
                                                self.config.max_seq_len))]
    words = [data['word_seq'][0, i].numpy().decode('utf-8')
             for i in range(0, num_tokens)]
    enrefs = self.encodings.build_enrefs_from_predictions(
        token_encs, pred_encs, words, self.current_enrefs)
    # logging.info('New Enrefs for %s: %s', words, enrefs)
    logging.info('%d new enrefs', len(enrefs))
    self.current_enrefs += enrefs

    return (data, y_pred)


class ContrackLoss(tf.keras.losses.Loss):
  """The loss function used for contrack training."""

  def __init__(self, mode):
    super(ContrackLoss, self).__init__(
        reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
    self.encodings = Env.get().encodings
    self.config = Env.get().config
    self.mode = mode

  @classmethod
  def from_config(cls, config):
    return ContrackLoss(config['mode'])

  def get_config(self):
    return {'mode': self.mode}

  def _compute_hinge_losses(self, labels, logits):
    """Computes hinge losses, mapping {0, 1} labels to {-1, 1}."""
    all_ones = tf.ones_like(labels)
    labels = tf.math.subtract(2 * labels, all_ones)
    losses = tf.nn.relu(
        tf.math.subtract(all_ones, tf.math.multiply(labels, logits)))
    if len(losses.get_shape()) > 2:
      losses = tf.math.reduce_sum(losses, [2])
    return losses
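
  # Worked example: label 1 maps to +1, so logit 0.3 costs
  # max(0, 1 - 0.3) = 0.7; label 0 maps to -1, so the same logit costs
  # max(0, 1 + 0.3) = 1.3. Confident correct predictions (|logit| >= 1 with
  # the right sign) incur zero loss.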

  def _compute_annotation_losses(self, target, predicted):
    """Computes loss comparing predicted and actual annotations."""
    target_enc = self.encodings.as_prediction_encoding(target)
    predicted_enc = self.encodings.as_prediction_encoding(predicted)

    # Membership loss only for groups
    loss = self._compute_hinge_losses(
        labels=target_enc.enref_membership.slice(),
        logits=predicted_enc.enref_membership.slice())
    loss *= target_enc.enref_properties.is_group()

    # Properties loss
    loss += self._compute_hinge_losses(
        labels=target_enc.enref_properties.slice(),
        logits=predicted_enc.enref_properties.slice())

    # Entity ID loss
    is_new = (
        target_enc.enref_meta.is_new() +
        target_enc.enref_meta.is_new_continued())

    existing_entity_loss = self._compute_hinge_losses(
        labels=target_enc.enref_id.slice(),
        logits=predicted_enc.enref_id.slice())
    existing_entity_loss *= 1.0 - is_new

    new_entity_loss = self._compute_hinge_losses(
        labels=target_enc.enref_meta.get_is_new_slice(),
        logits=predicted_enc.enref_meta.get_is_new_slice())
    new_entity_loss *= is_new
    fn_cost = self.config.new_id_false_negative_cost - 1.0
    new_entity_loss *= tf.ones_like(new_entity_loss) + fn_cost * is_new

    loss += existing_entity_loss + new_entity_loss

    # Is_entity loss
    not_an_entity_loss = self._compute_hinge_losses(
        labels=target_enc.enref_meta.slice(),
        logits=predicted_enc.enref_meta.slice())
    loss *= target_enc.enref_meta.is_enref()
    loss += not_an_entity_loss

    return loss

  def _compute_new_id_losses(self, target, predicted):
    """Computes the new id losses for each token in each turn."""
    target_meta = self.encodings.as_prediction_encoding(target).enref_meta
    predicted_meta = self.encodings.as_prediction_encoding(predicted).enref_meta

    losses = self._compute_hinge_losses(
        labels=target_meta.get_is_new_slice(),
        logits=predicted_meta.get_is_new_slice())
    fn_cost = self.config.new_id_false_negative_cost - 1.0
    new_entity_positives = target_meta.is_new() + target_meta.is_new_continued()
    losses *= tf.ones_like(losses) + fn_cost * new_entity_positives
    return losses

  def call(self, target, predicted):
    """Computes loss comparing predicted and actual annotations."""
    with tf.name_scope('contrack_loss'):
      if self.mode == 'only_new_entities':
        return self._compute_new_id_losses(target, predicted)
      else:
        return self._compute_annotation_losses(target, predicted)


def _get_named_slices(y_true, logits, section_name):
  """Returns the named slices of the true and predicted vectors."""
  is_entity = tf.expand_dims(y_true.enref_meta.is_enref(), 2)
  if section_name == 'new_entity':
    return (y_true.enref_meta.get_is_new_slice(),
            is_entity * logits.enref_meta.get_is_new_slice())
  elif section_name == 'entities':
    return (y_true.enref_id.slice(), is_entity * logits.enref_id.slice())
  elif section_name == 'properties':
    return (y_true.enref_properties.slice(),
            is_entity * logits.enref_properties.slice())
  elif section_name == 'membership':
    is_group = tf.expand_dims(y_true.enref_properties.is_group(), 2)
    return (y_true.enref_membership.slice(),
            is_entity * is_group * logits.enref_membership.slice())
  else:
    raise ValueError('Unknown section name %s' % section_name)


class ContrackAccuracy(tf.keras.metrics.Mean):
  """Computes zero-one accuracy on a given slice of the result vector."""

  def __init__(self, section_name, dtype=None):
    self.encodings = Env.get().encodings
    self.section_name = section_name
    super(ContrackAccuracy, self).__init__(
        name=f'{section_name}/accuracy', dtype=dtype)

  @classmethod
  def from_config(cls, config):
    return ContrackAccuracy(config['section_name'])

  def get_config(self):
    return {'section_name': self.section_name}

  def update_state(self, y_true, logits, sample_weight=None):
    y_true, logits = _get_named_slices(
        self.encodings.as_prediction_encoding(y_true),
        self.encodings.as_prediction_encoding(logits), self.section_name)
    y_pred = tf.cast(logits > 0.0, tf.float32)

    matches = tf.reduce_max(tf.cast(y_true == y_pred, tf.float32), -1)

    super(ContrackAccuracy, self).update_state(matches, sample_weight)


class ContrackPrecision(tf.keras.metrics.Precision):
  """Computes precision on a given slice of the result vector."""

  def __init__(self, section_name, dtype=None):
    self.encodings = Env.get().encodings
    self.section_name = section_name
    super(ContrackPrecision, self).__init__(
        name=f'{section_name}/precision', dtype=dtype)

  @classmethod
  def from_config(cls, config):
    return ContrackPrecision(config['section_name'])

  def get_config(self):
    return {'section_name': self.section_name}

  def update_state(self, y_true, logits, sample_weight=None):
    y_true, logits = _get_named_slices(
        self.encodings.as_prediction_encoding(y_true),
        self.encodings.as_prediction_encoding(logits), self.section_name)
    y_pred = tf.cast(logits > 0.0, tf.float32)

    super(ContrackPrecision, self).update_state(y_true, y_pred, sample_weight)


class ContrackRecall(tf.keras.metrics.Recall):
  """Computes recall on a given slice of the result vector."""

  def __init__(self, section_name, dtype=None):
    self.encodings = Env.get().encodings
    self.section_name = section_name
    super(ContrackRecall, self).__init__(
        name=f'{section_name}/recall', dtype=dtype)

  @classmethod
  def from_config(cls, config):
    return ContrackRecall(config['section_name'])

  def get_config(self):
    return {'section_name': self.section_name}

  def update_state(self, y_true, logits, sample_weight=None):
    y_true, logits = _get_named_slices(
        self.encodings.as_prediction_encoding(y_true),
        self.encodings.as_prediction_encoding(logits), self.section_name)
    y_pred = tf.cast(logits > 0.0, tf.float32)

    super(ContrackRecall, self).update_state(y_true, y_pred, sample_weight)


class ContrackF1Score(tf.keras.metrics.Metric):
  """Computes the f1 score on a given slice of the result vector."""

  def __init__(self, section_name, dtype=None):
    self.encodings = Env.get().encodings
    self.section_name = section_name
    self.precision = ContrackPrecision(section_name, dtype=dtype)
    self.recall = ContrackRecall(section_name, dtype=dtype)
    super(ContrackF1Score, self).__init__(
        name=f'{section_name}/f1score', dtype=dtype)

  @classmethod
  def from_config(cls, config):
    return ContrackF1Score(config['section_name'])

  def get_config(self):
    return {'section_name': self.section_name}

  def add_weight(self, **kwargs):
    self.precision.add_weight(**kwargs)
    self.recall.add_weight(**kwargs)

  def reset_states(self):
    self.precision.reset_states()
    self.recall.reset_states()

  def result(self):
    precision = self.precision.result()
    recall = self.recall.result()
    return 2.0 * (precision * recall) / (precision + recall +
                                         tf.keras.backend.epsilon())

  def update_state(self, y_true, logits, sample_weight=None):
    self.precision.update_state(y_true, logits, sample_weight)
    self.recall.update_state(y_true, logits, sample_weight)


def build_metrics(mode):
  """Creates a list of metrics for all metric types and sections."""
  if mode == 'only_new_entities':
    sections = ['new_entity']
  else:
    sections = ['new_entity', 'entities', 'properties', 'membership']

  metrics = []
  for section in sections:
    metrics += [
        ContrackAccuracy(section),
        ContrackPrecision(section),
        ContrackRecall(section),
        ContrackF1Score(section)
    ]
  return metrics


def get_custom_objects():
  """Returns the custom classes needed to deserialize a saved model."""
  return {
      'ContrackModel': ContrackModel,
      'ConvertToSequenceLayer': ConvertToSequenceLayer,
      'IdentifyNewEntityLayer': IdentifyNewEntityLayer,
      'ComputeIdsLayer': ComputeIdsLayer,
      'TrackEnrefsLayer': TrackEnrefsLayer,
      'ContrackLoss': ContrackLoss,
      'ContrackAccuracy': ContrackAccuracy,
      'ContrackPrecision': ContrackPrecision,
      'ContrackRecall': ContrackRecall,
      'ContrackF1Score': ContrackF1Score,
  }
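

# A minimal deserialization sketch (the checkpoint path is hypothetical):
#   model = tf.keras.models.load_model(
#       '/tmp/contrack_model', custom_objects=get_custom_objects())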