google-research
1089 lines · 36.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""The actual Contrack model."""
17
18import logging19import math20from typing import Any, Callable, Dict, Iterator, List, Tuple21
22import tensorflow as tf23
24from contrack import custom_ops25from contrack import encoding26from contrack.env import Env27
28
def _pad_and_clip_seq_batch(seq_batch, seq_len_batch,
                            pad_value, maxlen,
                            data_vec_len):
  """Clips a batch of sequences to `maxlen` and pads it with `pad_value`.

  Args:
    seq_batch: float tensor [batch, seq, data_vec_len] of sequence data.
    seq_len_batch: int tensor [batch] with the valid length of each sequence.
    pad_value: scalar written into every position outside the valid length.
    maxlen: target length of the sequence axis.
    data_vec_len: size of the trailing feature dimension.

  Returns:
    A [batch, maxlen, data_vec_len] tensor where positions beyond each
    sequence's length hold `pad_value`.
  """
  with tf.name_scope('pad_seq_batch'):
    # Boolean mask of valid positions, broadcast across the feature dim.
    valid = tf.sequence_mask(lengths=seq_len_batch, maxlen=maxlen)
    valid = tf.tile(tf.expand_dims(valid, 2), [1, 1, data_vec_len])

    # Clip the sequence axis to maxlen, then right-pad it back to maxlen so
    # the shape is static regardless of the incoming length.
    clipped = seq_batch[:, :maxlen, :]
    missing = tf.constant(maxlen) - tf.shape(clipped)[1]
    clipped = tf.pad(clipped, paddings=[[0, 0], [0, missing], [0, 0]])
    clipped.set_shape([clipped.shape[0], maxlen, data_vec_len])

    # Overwrite everything outside the valid region with pad_value.
    fill_value = tf.cast(pad_value, dtype=clipped.dtype)
    fill = tf.fill([tf.shape(seq_batch)[0], maxlen, data_vec_len],
                   value=fill_value)
    result = tf.where(valid, clipped, fill)
  return result
49
def shape_list(x):
  """Return list of dims, statically where possible."""
  x = tf.convert_to_tensor(x)

  # Rank unknown: there is nothing static to report, fall back entirely to
  # the dynamic shape tensor.
  if x.get_shape().dims is None:
    return tf.shape(x)

  dynamic = tf.shape(x)
  # Prefer the compile-time dimension; substitute the runtime value only
  # where the static shape is unknown.
  return [
      dim if dim is not None else dynamic[i]
      for i, dim in enumerate(x.get_shape().as_list())
  ]
68
def split_heads(x, n):
  """Splits the last dim of [batch, length, m] into n attention heads.

  Returns a tensor of shape [batch, n, length, m // n].
  """
  dims = shape_list(x)
  depth = dims[-1]
  # Verify divisibility statically when both sizes are known at trace time.
  if isinstance(depth, int) and isinstance(n, int):
    assert depth % n == 0
  split = tf.reshape(x, dims[:-1] + [n, depth // n])
  # Move the head axis in front of the length axis.
  return tf.transpose(split, [0, 2, 1, 3])
78
def combine_heads(x):
  """Inverse of split_heads: [batch, n, length, d] -> [batch, length, n * d]."""
  merged = tf.transpose(x, [0, 2, 1, 3])
  dims = shape_list(merged)
  heads, depth = dims[-2:]
  return tf.reshape(merged, dims[:-2] + [heads * depth])
85
class ConvertToSequenceLayer(tf.keras.layers.Layer):
  """Concatenates input data into a sequence suitable for prediction.

  Joins the enref state sequence and the token sequence into one padded
  sequence of length config.max_seq_len and optionally appends a sinusoidal
  timing signal to each position's feature vector.
  """

  def __init__(self, input_vec_len):
    super(ConvertToSequenceLayer, self).__init__()
    self.config = Env.get().config
    # Feature-vector length of each element in the concatenated sequence.
    self.input_vec_len = input_vec_len

  @classmethod
  def from_config(cls, config):
    return ConvertToSequenceLayer(config['input_vec_len'])

  def get_config(self):
    return {'input_vec_len': self.input_vec_len}

  def compute_mask(self, inputs, mask=None):
    # Valid positions are the first state_seq_length + token_seq_length
    # entries of the concatenated sequence.
    state_seq_len = inputs['state_seq_length']
    token_seq_len = inputs['token_seq_length']

    input_seq_len = tf.add(state_seq_len, token_seq_len)

    return tf.sequence_mask(input_seq_len, maxlen=self.config.max_seq_len)

  def call(self, inputs, training=None):
    """Builds the concatenated, padded input sequence.

    Args:
      inputs: dict with 'state_seq'/'state_seq_length' (enref state) and
        'token_seq'/'token_seq_length' (message tokens).
      training: unused Keras training flag.

    Returns:
      A (input_seq, input_seq_len) tuple; input_seq has shape
      [batch, max_seq_len, input_vec_len (+ timing_signal_size)].
    """
    with tf.name_scope('convert_to_sequence'):
      state_seq_len = tf.cast(inputs['state_seq_length'], tf.int32)
      state_seq = inputs['state_seq']

      token_seq_len = tf.cast(inputs['token_seq_length'], tf.int32)
      token_seq = inputs['token_seq']

      # Concatenate the (ragged) state and token sequences per example.
      input_seq, input_seq_len = custom_ops.sequence_concat(
          sequences=[state_seq, token_seq],
          lengths=[state_seq_len, token_seq_len])

      # Clip and pad seq to the fixed model length.
      input_seq_len = tf.minimum(
          input_seq_len, self.config.max_seq_len, name='input_seq_len')
      input_seq = _pad_and_clip_seq_batch(
          input_seq,
          input_seq_len,
          pad_value=0,
          maxlen=self.config.max_seq_len,
          data_vec_len=self.input_vec_len)

      # Add a Transformer-style sinusoidal timing signal as extra channels.
      if self.config.timing_signal_size > 0:
        num_channels = self.config.timing_signal_size
        positions = tf.cast(tf.range(self.config.max_seq_len), dtype=tf.int64)

        min_timescale = 1.0
        max_timescale = 1.0e4

        with tf.name_scope('TimingSignal'):
          # Half the channels carry sine, half cosine, over geometrically
          # spaced timescales between min_timescale and max_timescale.
          num_timescales = num_channels // 2
          log_timescale_increment = (
              math.log(max_timescale / min_timescale) /
              (tf.cast(num_timescales, tf.float32) - 1))
          inv_timescales = min_timescale * tf.exp(
              tf.cast(tf.range(num_timescales), tf.float32) *
              -log_timescale_increment)
          scaled_time = (
              tf.expand_dims(tf.cast(positions, tf.float32), 1) *
              tf.expand_dims(inv_timescales, 0))
          signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
          # Pad by one channel when num_channels is odd.
          time_signal = tf.pad(
              signal, [[0, 0], [0, tf.math.floormod(num_channels, 2)]])

        # Broadcast over the batch and zero the signal on padding positions.
        time_signal = tf.expand_dims(time_signal, 0)
        time_signal = tf.tile(time_signal, [tf.shape(input_seq)[0], 1, 1])
        seq_mask = tf.cast(
            tf.sequence_mask(
                lengths=input_seq_len, maxlen=self.config.max_seq_len),
            dtype=tf.float32)
        seq_mask = tf.expand_dims(seq_mask, 2)
        seq_mask = tf.tile(seq_mask, [1, 1, num_channels])
        time_signal = time_signal * seq_mask
        input_seq = tf.concat([input_seq, time_signal],
                              axis=2,
                              name='add_time_signal')

    return input_seq, input_seq_len
171
class IdentifyNewEntityLayer(tf.keras.layers.Layer):
  """The layer identifying new entities in a message.

  A single self-attention block (layer norm, multi-head attention built from
  explicit q/k/v dense projections, residual, feed-forward) followed by an
  affine projection to 2 logits per position: new entity / new continued.
  """

  def __init__(self, seq_shape):
    super(IdentifyNewEntityLayer, self).__init__()
    self.config = Env.get().config.new_id_attention
    self.supports_masking = True
    # [batch, seq_len, feature] shape of the incoming sequence.
    self.seq_shape = seq_shape

    self.batch_size = seq_shape[0]
    self.seq_len = seq_shape[1]
    # Per-head key size; rounded up so heads cover the whole feature space.
    key_dim = math.ceil(seq_shape[2] / self.config.num_heads)
    output_shape = [
        self.batch_size, self.seq_len, key_dim * self.config.num_heads
    ]
    # NOTE(review): self.attention is constructed but call() builds attention
    # manually from q/k/v dense layers below — confirm before relying on it.
    self.attention = tf.keras.layers.MultiHeadAttention(
        num_heads=self.config.num_heads,
        key_dim=key_dim,
        value_dim=key_dim,
        use_bias=True,
        dropout=self.config.dropout_rate,
        output_shape=[output_shape[2]])

    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=2, epsilon=1e-6, name='SelfAttentionNorm')

    # Final 2-logit output: (is_new, is_new_continued).
    self.affine = tf.keras.layers.Dense(2, use_bias=True)

    self.q_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)
    self.k_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)
    self.v_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)

    hidden_size = 100
    filter_size = 800
    self.attention_dense = tf.keras.layers.Dense(hidden_size, use_bias=True)
    self.match_residual_dense = tf.keras.layers.Dense(
        hidden_size, use_bias=True)
    self.ffn_layer_norm = tf.keras.layers.LayerNormalization(
        axis=2, epsilon=1e-6, name='FFNNorm')

    self.ff_relu_dense = tf.keras.layers.Dense(
        filter_size, use_bias=True, activation='relu')
    self.ff_dense = tf.keras.layers.Dense(hidden_size, use_bias=True)

  @classmethod
  def from_config(cls, config):
    return IdentifyNewEntityLayer(config['seq_shape'])

  def get_config(self):
    return {'seq_shape': self.seq_shape}

  def compute_mask(self, inputs, mask=None):
    # Pass the incoming mask through unchanged.
    return mask

  def call(self, inputs, training=None, mask=None):
    """Returns per-position new-entity logits of shape [batch, seq, 2]."""
    x = inputs

    # Make feature space size a multiple of num_heads by zero-padding the
    # feature dimension.
    num_heads = self.config.num_heads
    if x.shape[-1] % num_heads > 0:
      with tf.name_scope('PadForMultipleOfHeads'):
        fill_size = num_heads - x.shape[-1] % num_heads
        fill_mat = tf.tile(tf.zeros_like(x[:, :, :1]), [1, 1, fill_size])
        x = tf.concat([x, fill_mat], 2)

    # Multihead attention (scaled dot-product, manual q/k/v projections).
    input_depth = x.shape[-1]
    x = self.layer_norm(x, training=training)
    q = split_heads(self.q_dense(x, training=training), num_heads)
    k = split_heads(self.k_dense(x, training=training), num_heads)
    v = split_heads(self.v_dense(x, training=training), num_heads)

    key_depth_per_head = input_depth // num_heads
    q *= key_depth_per_head**-0.5

    logits = tf.matmul(q, k, transpose_b=True)
    weights = tf.nn.softmax(logits, name='attention_weights')
    y = tf.matmul(weights, v)

    y = combine_heads(y)
    y = self.attention_dense(y)

    # Residual connection projected to the same hidden size, then norm.
    r = self.match_residual_dense(x)
    y = self.ffn_layer_norm(y + r, training=training)

    # Feed forward with residual.
    z = self.ff_relu_dense(y, training=training)
    z = self.ff_dense(z, training=training)
    z += y

    # Affine layer; re-attaches the first 68 input features to the hidden
    # state before projecting.
    # NOTE(review): 68 is a magic constant — presumably the length of some
    # fixed prefix of the enref encoding; confirm against contrack.encoding.
    z = tf.concat([inputs[:, :, :68], z], 2)
    logits = self.affine(z, training=training)

    # Zero out logits at padding positions.
    logits *= tf.expand_dims(tf.cast(mask, tf.float32), -1)

    return logits
276
class ComputeIdsLayer(tf.keras.layers.Layer):
  """Compute Ids for the (new entity) tokens in the input sequence."""

  def __init__(self):
    super(ComputeIdsLayer, self).__init__()
    # Shared environment singletons: encoding layout and model config.
    self.encodings = Env.get().encodings
    self.config = Env.get().config

  @classmethod
  def from_config(cls, config):
    # The layer has no configuration; nothing to restore.
    del config
    return ComputeIdsLayer()

  def get_config(self):
    return {}

  def compute_mask(self, inputs, mask=None):
    # inputs is the (seq, enref_seq_len, is_new_logits) tuple from call();
    # the produced mask covers the enref-state prefix of the sequence.
    _, seq_len, _ = inputs
    return tf.sequence_mask(seq_len, maxlen=self.config.max_seq_len)

  def call(self, inputs, mask=None):
    """Allocates a one-hot id for every token predicted to be a new entity.

    Args:
      inputs: tuple (seq, enref_seq_len, is_new_logits) — the concatenated
        input sequence, the per-example length of its enref-state prefix,
        and the logits from IdentifyNewEntityLayer.
      mask: unused Keras mask argument.

    Returns:
      One-hot id tensor for new entities; gradients are stopped because id
      allocation is a hard (non-differentiable) decision.
    """
    seq, enref_seq_len, is_new_logits = inputs

    enref_seq_len = tf.cast(enref_seq_len, dtype=tf.int32)

    # Ids already occupied by enrefs in the state prefix.
    enref_ids = self.encodings.as_enref_encoding(seq).enref_id.slice()

    # Channel 0 of the logits scores "token starts a new entity".
    is_new_entity = tf.cast(is_new_logits[:, :, 0] > 0.0, tf.float32)

    # Custom op assigns fresh ids that do not collide with existing ones.
    new_id_one_hot = custom_ops.new_id(
        state_ids=enref_ids, state_len=enref_seq_len, is_new=is_new_entity)
    new_id_one_hot = tf.stop_gradient(new_id_one_hot)

    return new_id_one_hot
316
class TrackEnrefsLayer(tf.keras.layers.Layer):
  """Predict enref Ids, properties, and group membership.

  Runs one self-attention block (layer norm, manual q/k/v multi-head
  attention, residual, feed-forward) over the concatenation of the input
  sequence, the new-entity logits and the newly assigned ids, then projects
  each position to a full prediction-encoding vector.
  """

  def __init__(self, seq_shape):
    super(TrackEnrefsLayer, self).__init__()

    self.config = Env.get().config.tracking_attention
    self.encodings = Env.get().encodings
    self.supports_masking = True
    # [batch, seq_len, feature] shape of the incoming sequence.
    self.seq_shape = seq_shape

    self.batch_size = seq_shape[0]
    self.seq_len = seq_shape[1]
    # Input depth grows by the 2 new-entity logits plus the new-id one-hot.
    attention_input_length = (
        seq_shape[2] + 2 + self.encodings.new_enref_encoding().enref_id.SIZE)
    key_dim = math.ceil(attention_input_length / self.config.num_heads)
    output_shape = [
        self.batch_size, self.seq_len, key_dim * self.config.num_heads
    ]
    # NOTE(review): self.attention is constructed but call() builds attention
    # manually from the q/k/v dense layers below — confirm before removing.
    self.attention = tf.keras.layers.MultiHeadAttention(
        num_heads=self.config.num_heads,
        key_dim=key_dim,
        value_dim=key_dim,
        use_bias=True,
        dropout=self.config.dropout_rate,
        output_shape=[output_shape[2]])

    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=2, epsilon=1e-6, name='SelfAttentionNorm')

    # Final projection to the full prediction encoding.
    self.affine = tf.keras.layers.Dense(
        self.encodings.prediction_encoding_length, use_bias=True)

    self.q_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)
    self.k_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)
    self.v_dense = tf.keras.layers.Dense(output_shape[2], use_bias=True)

    hidden_size = 100
    filter_size = 800
    self.attention_dense = tf.keras.layers.Dense(hidden_size, use_bias=True)
    self.match_residual_dense = tf.keras.layers.Dense(
        hidden_size, use_bias=True)
    self.ffn_layer_norm = tf.keras.layers.LayerNormalization(
        axis=2, epsilon=1e-6, name='FFNNorm')

    self.ff_relu_dense = tf.keras.layers.Dense(
        filter_size, use_bias=True, activation='relu')
    self.ff_dense = tf.keras.layers.Dense(hidden_size, use_bias=True)

  @classmethod
  def from_config(cls, config):
    # Bug fix: this previously constructed IdentifyNewEntityLayer, so
    # deserializing a saved model rebuilt the wrong layer here.
    return TrackEnrefsLayer(config['seq_shape'])

  def get_config(self):
    return {'seq_shape': self.seq_shape}

  def compute_mask(self, inputs, mask=None):
    # Propagate the mask of the first input (the sequence).
    return mask[0]

  def call(self, inputs, training=None, mask=None):
    """Returns per-position prediction-encoding logits.

    Args:
      inputs: tuple (seq, is_new_entity, new_ids).
      training: Keras training flag, forwarded to sublayers.
      mask: list of masks; mask[0] masks the sequence input.
    """
    seq, is_new_entity, new_ids = inputs

    # The new-entity decision is trained by its own layer; do not backprop
    # into it from the tracking loss.
    is_new_entity = tf.stop_gradient(is_new_entity)

    x = tf.concat([seq, is_new_entity, new_ids], axis=2)

    # Make feature space size a multiple of num_heads by zero-padding.
    num_heads = self.config.num_heads
    if x.shape[-1] % num_heads > 0:
      with tf.name_scope('PadForMultipleOfHeads'):
        fill_size = num_heads - x.shape[-1] % num_heads
        fill_mat = tf.tile(tf.zeros_like(x[:, :, :1]), [1, 1, fill_size])
        x = tf.concat([x, fill_mat], 2)

    # Multihead attention (scaled dot-product, manual q/k/v projections).
    input_depth = x.shape[-1]
    x = self.layer_norm(x, training=training)
    q = split_heads(self.q_dense(x, training=training), num_heads)
    k = split_heads(self.k_dense(x, training=training), num_heads)
    v = split_heads(self.v_dense(x, training=training), num_heads)

    key_depth_per_head = input_depth // num_heads
    q *= key_depth_per_head**-0.5

    logits = tf.matmul(q, k, transpose_b=True)
    weights = tf.nn.softmax(logits, name='attention_weights')
    y = tf.matmul(weights, v)

    y = combine_heads(y)
    y = self.attention_dense(y)

    # Residual connection projected to the hidden size, then norm.
    r = self.match_residual_dense(x)
    y = self.ffn_layer_norm(y + r, training=training)

    # Feed forward with residual.
    z = self.ff_relu_dense(y, training=training)
    z = self.ff_dense(z, training=training)
    z += y

    # Affine projection to the prediction encoding.
    logits = self.affine(z, training=training)

    # Zero out logits at padding positions.
    logits *= tf.expand_dims(tf.cast(mask[0], tf.float32), -1)

    return logits
429
class MergeIdsLayer(tf.keras.layers.Layer):
  """Layer merging the ids from new entities and existing entities."""

  def __init__(self):
    super(MergeIdsLayer, self).__init__()
    self.encodings = Env.get().encodings

  @classmethod
  def from_config(cls, config):
    # Stateless layer: nothing to restore from config.
    del config
    return MergeIdsLayer()

  def get_config(self):
    return {}

  def call(self, inputs, training=None, mask=None):
    """Combines newly allocated ids with ids predicted for known enrefs.

    Args:
      inputs: tuple (new_entity_logits, new_ids, prediction_logits).
      training: unused Keras training flag.
      mask: unused Keras mask argument.

    Returns:
      prediction_logits with the enref-id slice replaced by the merged ids
      and the is-new slice replaced by the new-entity logits.
    """
    new_entity_logits, new_ids, prediction_logits = inputs

    # Ids predicted for entities that already exist in the state.
    pred_enc = self.encodings.as_prediction_encoding(prediction_logits)
    existing_ids = pred_enc.enref_id.slice()

    # A token refers to a new entity if any of its new-entity logits is
    # positive; collapse the logit dimension into a 0/1 gate.
    new_gate = tf.reduce_max(
        tf.cast(new_entity_logits > 0.0, tf.float32), 2, keepdims=True)

    # Take the freshly allocated id where the gate fires, otherwise keep the
    # id predicted for the existing entity.
    merged_ids = new_gate * new_ids + (1.0 - new_gate) * existing_ids

    merged = pred_enc.enref_id.replace(merged_ids)
    merged = self.encodings.as_prediction_encoding(
        merged).enref_meta.replace_is_new_slice(new_entity_logits)

    return merged
466
class ContrackModel(tf.keras.Model):
  """The Contrack model.

  Pipeline: concatenate enref state + tokens into one sequence, detect new
  entities, allocate ids for them, predict full enref encodings, and merge
  new/existing ids. `mode` selects 'only_new_entities', 'only_tracking', or
  full training.
  """

  def __init__(self, mode, print_predictions=False):
    super(ContrackModel, self).__init__()

    self.config = Env.get().config
    self.encodings = Env.get().encodings
    # One of 'only_new_entities', 'only_tracking', or full training.
    self.mode = mode
    self.print_predictions = print_predictions
    # When True (default), evaluation feeds the gold enref state; see
    # disable_teacher_forcing() for the autoregressive mode.
    self.teacher_forcing = True

    self.convert_to_sequence_layer = ConvertToSequenceLayer(
        self.encodings.enref_encoding_length)

    input_shape = [
        self.config.batch_size, self.config.max_seq_len,
        self.encodings.enref_encoding_length + self.config.timing_signal_size
    ]
    self.new_entity_layer = IdentifyNewEntityLayer(input_shape)

    self.compute_ids_layer = ComputeIdsLayer()

    self.track_enrefs_layer = TrackEnrefsLayer(input_shape)

    self.merge_ids_layer = MergeIdsLayer()

  @classmethod
  def from_config(cls, config):
    return ContrackModel(config['mode'])

  def get_config(self):
    return {'mode': self.mode}

  def init_weights_from_new_entity_model(self, model):
    """Copies new-entity-layer weights from a pretrained model."""
    # Call the model once to create weights
    input_vec_shape = [
        self.config.batch_size, self.config.max_seq_len,
        self.encodings.token_encoding_length
    ]
    null_input = {
        'state_seq_length': tf.ones([self.config.batch_size]),
        'state_seq': tf.zeros(input_vec_shape),
        'token_seq_length': tf.ones([self.config.batch_size]),
        'token_seq': tf.zeros(input_vec_shape)
    }
    self(null_input)

    # Then copy over layer weights
    self.new_entity_layer.set_weights(model.new_entity_layer.get_weights())

  def call(self, inputs, training=False):
    """Runs the full prediction pipeline; returns prediction logits."""
    # Step 1: Concatenate input data into a single sequence
    seq, _ = self.convert_to_sequence_layer(inputs)

    # Step 2: Identify new entities.
    is_new_entity = self.new_entity_layer(seq)

    if self.mode == 'only_new_entities':
      # Short-circuit: emit a zero prediction vector carrying only the
      # new-entity logits.
      res = tf.zeros_like(seq[:, :, :self.encodings.prediction_encoding_length])
      res_enc = self.encodings.as_prediction_encoding(res)
      res = res_enc.enref_meta.replace_is_new_slice(is_new_entity)
      return res
    elif self.mode == 'only_tracking':
      # Freeze the new-entity detector when training only the tracker.
      is_new_entity = tf.stop_gradient(is_new_entity)

    # Step 3: Compute enref ids for new entities
    new_ids = self.compute_ids_layer(
        (seq, inputs['state_seq_length'], is_new_entity))

    # Step 4: Determine enref predictions
    logits = self.track_enrefs_layer((seq, is_new_entity, new_ids))

    # Step 5: Merge ids from new and existing enrefs
    logits = self.merge_ids_layer((is_new_entity, new_ids, logits))

    return logits

  def train_step(self, data):
    """The training step."""
    x = data

    # Shift true labels seq to align with tokens in input_seq: the model's
    # output starts with the enref-state prefix, so pad the labels with
    # zeros of that length.
    state_seq_len = tf.cast(data['state_seq_length'], tf.int32)
    token_seq_len = tf.cast(data['token_seq_length'], tf.int32)
    state_seq_dims = tf.shape(data['state_seq'])
    enref_padding = tf.zeros([
        state_seq_dims[0], state_seq_dims[1],
        self.encodings.prediction_encoding_length
    ])
    y, y_len = custom_ops.sequence_concat(
        sequences=[enref_padding, data['annotation_seq']],
        lengths=[state_seq_len, token_seq_len])

    # Clip and pad true labels seq
    y_len = tf.minimum(y_len, self.config.max_seq_len, name='y_len')
    y = _pad_and_clip_seq_batch(
        y,
        y_len,
        pad_value=0,
        maxlen=self.config.max_seq_len,
        data_vec_len=self.encodings.prediction_encoding_length)

    # Weight only token positions (sequence minus the enref-state prefix).
    input_seq_len = tf.add(state_seq_len, token_seq_len)
    seq_mask = tf.sequence_mask(
        input_seq_len, maxlen=self.config.max_seq_len, dtype=tf.float32)
    enref_mask = tf.sequence_mask(
        state_seq_len, maxlen=self.config.max_seq_len, dtype=tf.float32)
    sample_weight = seq_mask - enref_mask

    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)
      loss = self.compiled_loss(
          y, y_pred, sample_weight, regularization_losses=self.losses)
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {m.name: m.result() for m in self.metrics}

  def predict_step(self, data):
    """The logic for one inference step."""
    if not self.teacher_forcing:
      data, y_pred = self.call_without_teacher_forcing(data)
    else:
      y_pred = self(data, training=False)

    # Re-pad every input feature to max_seq_len so the returned dict has
    # uniform static shapes.
    x = {
        'state_seq_length': data['state_seq_length'],
        'token_seq_length': data['token_seq_length'],
        'scenario_id': data['scenario_id']
    }
    x['state_seq'] = _pad_and_clip_seq_batch(
        data['state_seq'],
        data['state_seq_length'],
        pad_value=0,
        maxlen=self.config.max_seq_len,
        data_vec_len=self.encodings.enref_encoding_length)
    x['token_seq'] = _pad_and_clip_seq_batch(
        data['token_seq'],
        data['token_seq_length'],
        pad_value=0,
        maxlen=self.config.max_seq_len,
        data_vec_len=self.encodings.token_encoding_length)
    x['word_seq'] = _pad_and_clip_seq_batch(
        tf.expand_dims(data['word_seq'], -1),
        data['token_seq_length'],
        pad_value='',
        maxlen=self.config.max_seq_len,
        data_vec_len=1)
    x['annotation_seq'] = _pad_and_clip_seq_batch(
        data['annotation_seq'],
        data['token_seq_length'],
        pad_value=0,
        maxlen=self.config.max_seq_len,
        data_vec_len=self.encodings.prediction_encoding_length)

    return (x, y_pred)

  def make_test_function(self):
    """Creates a function that executes one step of evaluation."""
    test_fn = super(ContrackModel, self).make_test_function()

    def adapted_test_fn(iterator):
      # Wrap the standard test function to pull the side-channel
      # 'print_prediction' string out of the outputs and log it.
      outputs = test_fn(iterator)
      if 'print_prediction' in outputs:
        pred_msgs = outputs['print_prediction']
        if self.print_predictions:
          logging.info(pred_msgs.numpy().decode('utf-8'))
        del outputs['print_prediction']
      return outputs

    self.test_function = adapted_test_fn
    return adapted_test_fn

  def print_prediction(self, seq_len, state_seq_len, words, tokens,
                       predictions, true_targets):
    """Formats predicted vs. true labels per token into a readable string.

    Mismatches are marked with '***'. Runs eagerly via tf.py_function.
    """
    res = ''

    for batch_index, num_token in enumerate(seq_len.numpy()):
      res += '---------------------------------------\n'
      for i in range(num_token):
        word = words[batch_index, i].numpy().decode('utf-8')
        res += word + ': '

        # Token i lives at position state_seq_len + i of the model output.
        seq_index = state_seq_len[batch_index] + i
        if seq_index >= self.config.max_seq_len:
          break

        true_target = self.encodings.as_prediction_encoding(
            true_targets[batch_index, seq_index, :].numpy())

        pred = self.encodings.as_prediction_encoding(
            predictions[batch_index, seq_index, :].numpy())

        if self.mode == 'only_new_entities':
          # Compact labels: 'n' = new, 'c' = new-continued.
          true_label = '%s%s' % (
              'n' if true_target.enref_meta.is_new() > 0.0 else '',
              'c' if true_target.enref_meta.is_new_continued() > 0.0 else '')
          predicted_label = '%s%s' % (
              'n' if pred.enref_meta.is_new() > 0.0 else '',
              'c' if pred.enref_meta.is_new_continued() > 0.0 else '')
          if true_label != predicted_label:
            res += '*** %s != %s' % (predicted_label, true_label)
            res += ' ' + str(pred.enref_meta.slice())
          else:
            res += predicted_label
        else:
          token = self.encodings.as_token_encoding(tokens[batch_index, i, :])
          true_enref = self.encodings.build_enref_from_prediction(
              token, true_target)
          pred_enref = self.encodings.build_enref_from_prediction(token, pred)
          if str(true_enref) != str(pred_enref):
            res += '*** %s != %s' % (str(pred_enref), str(true_enref))
            res += ' %s' % str([
                round(a, 2)
                for a in true_targets[batch_index, seq_index, :].numpy()
            ])
          else:
            res += str(pred_enref) if pred_enref is not None else ''

        res += '\n'

    return res

  def test_step(self, data):
    """The logic for one evaluation step."""
    x = data

    # Shift true labels seq to align with tokens in input_seq (same label
    # construction as train_step).
    state_seq_len = tf.cast(data['state_seq_length'], tf.int32)
    token_seq_len = tf.cast(data['token_seq_length'], tf.int32)
    state_seq_dims = tf.shape(data['state_seq'])
    enref_padding = tf.zeros([
        state_seq_dims[0], state_seq_dims[1],
        self.encodings.prediction_encoding_length
    ])
    y, y_len = custom_ops.sequence_concat(
        sequences=[enref_padding, data['annotation_seq']],
        lengths=[state_seq_len, token_seq_len])

    # Clip and pad true labels seq
    y_len = tf.minimum(y_len, self.config.max_seq_len, name='y_len')
    y = _pad_and_clip_seq_batch(
        y,
        y_len,
        pad_value=0,
        maxlen=self.config.max_seq_len,
        data_vec_len=self.encodings.prediction_encoding_length)

    input_seq_len = tf.add(state_seq_len, token_seq_len)
    seq_mask = tf.sequence_mask(
        input_seq_len, maxlen=self.config.max_seq_len, dtype=tf.float32)
    enref_mask = tf.sequence_mask(
        state_seq_len, maxlen=self.config.max_seq_len, dtype=tf.float32)
    sample_weight = seq_mask - enref_mask

    y_pred = self(x, training=False)
    # Updates stateful loss metrics.
    self.compiled_loss(
        y, y_pred, sample_weight, regularization_losses=self.losses)

    self.compiled_metrics.update_state(y, y_pred, sample_weight)

    # Print prediction to log (extracted again in adapted_test_fn).
    output_tensors = {m.name: m.result() for m in self.metrics}

    print_prediction_fn = tf.py_function(self.print_prediction, [
        x['token_seq_length'], x['state_seq_length'], x['word_seq'],
        x['token_seq'], y_pred, y
    ], tf.string)

    output_tensors['print_prediction'] = print_prediction_fn
    return output_tensors

  def disable_teacher_forcing(self):
    """Switches evaluation to feed back the model's own enref predictions."""
    self.teacher_forcing = False
    self.current_scenario = None
    self.current_enrefs = []
    self.current_participants = []
    # The stateful enref bookkeeping below only supports batch size 1.
    assert self.config.batch_size == 1

  def call_without_teacher_forcing(self, data):
    """Runs one turn, maintaining the enref state across turns in Python."""
    scenario_id = data['scenario_id'][0].numpy().decode('utf-8')
    if scenario_id == self.current_scenario:
      # Continue existing conversations, create state_seq from enrefs
      enrefs = self.current_enrefs
      logging.info('Continue conversation with %d enrefs', len(enrefs))

      data['state_seq_length'] = tf.constant([len(enrefs)], dtype=tf.int64)
      sender = data['sender'][0].numpy().decode('utf-8')
      for enref in enrefs:
        # Update per-turn context: sender/recipient flags and message age.
        entity_name = enref.entity_name
        enref.enref_context.set_is_sender(entity_name == sender)
        enref.enref_context.set_is_recipient(
            entity_name != sender and entity_name in self.current_participants)
        enref.enref_context.set_message_offset(
            enref.enref_context.get_message_offset() + 1)

      enref_seq = [e.array for e in enrefs]
      data['state_seq'] = tf.constant([enref_seq], dtype=tf.float32)
    else:
      # Start new conversation, obtain initial enrefs from participants
      logging.info('New conversation, participants %s',
                   data['participants'].values)
      self.current_scenario = scenario_id
      self.current_participants = [
          p.numpy().decode('utf-8') for p in data['participants'].values
      ]
      self.current_enrefs = []
      for i in range(0, data['state_seq_length'][0].numpy()):
        enref_array = data['state_seq'][0, i].numpy()
        enref = self.encodings.as_enref_encoding(enref_array)
        enref_name = self.current_participants[i]
        enref.populate(enref_name, (i, i + 1), enref_name)
        self.current_enrefs.append(enref)

    # Run model
    y_pred = self(data, training=False)

    # Update set of enrefs from prediction
    num_tokens = len(data['word_seq'][0])
    num_enrefs = len(data['state_seq'][0])

    token_encs = [self.encodings.as_token_encoding(
        data['token_seq'][0, i, :].numpy()) for i in range(0, num_tokens)]
    # Token predictions start after the enref-state prefix in the output.
    pred_encs = [self.encodings.as_prediction_encoding(y_pred[0, i, :].numpy())
                 for i in range(num_enrefs, min(num_enrefs + num_tokens,
                                                self.config.max_seq_len))]
    words = [data['word_seq'][0, i].numpy().decode('utf-8')
             for i in range(0, num_tokens)]
    enrefs = self.encodings.build_enrefs_from_predictions(
        token_encs, pred_encs, words, self.current_enrefs)
    logging.info('%d new enrefs', len(enrefs))
    self.current_enrefs += enrefs

    return (data, y_pred)
815
class ContrackLoss(tf.keras.losses.Loss):
  """The loss function used for contrack training.

  In 'only_new_entities' mode only the new-entity hinge loss is computed;
  otherwise the full annotation loss (membership, properties, ids, is-entity)
  is used.
  """

  def __init__(self, mode):
    super(
        ContrackLoss,
        self).__init__(reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
    self.encodings = Env.get().encodings
    self.config = Env.get().config
    self.mode = mode

  @classmethod
  def from_config(cls, config):
    return ContrackLoss(config['mode'])

  def get_config(self):
    return {'mode': self.mode}

  def _compute_hinge_losses(self, labels, logits):
    """Computes hinge loss.

    Args:
      labels: 0/1 target tensor.
      logits: raw score tensor of the same shape.

    Returns:
      Per-position hinge losses; a trailing feature dimension (rank > 2) is
      summed away.
    """
    all_ones = tf.ones_like(labels)
    # Map {0, 1} labels to {-1, +1} as the hinge formulation requires.
    labels = tf.math.subtract(2 * labels, all_ones)
    losses = tf.nn.relu(
        tf.math.subtract(all_ones, tf.math.multiply(labels, logits)))
    if len(losses.get_shape()) > 2:
      losses = tf.math.reduce_sum(losses, [2])
    return losses

  def _compute_annotation_losses(self, target, predicted):
    """Compute loss comparing predicted and actual annotations."""
    target_enc = self.encodings.as_prediction_encoding(target)
    predicted_enc = self.encodings.as_prediction_encoding(predicted)

    # Membership loss only for groups
    loss = self._compute_hinge_losses(
        labels=target_enc.enref_membership.slice(),
        logits=predicted_enc.enref_membership.slice())
    loss *= target_enc.enref_properties.is_group()

    # Properties loss
    loss += self._compute_hinge_losses(
        labels=target_enc.enref_properties.slice(),
        logits=predicted_enc.enref_properties.slice())

    # Entity ID loss: id targets only apply to existing entities; new
    # entities are scored on the is-new slice instead.
    is_new = (
        target_enc.enref_meta.is_new() +
        target_enc.enref_meta.is_new_continued())

    existing_entity_loss = self._compute_hinge_losses(
        labels=target_enc.enref_id.slice(),
        logits=predicted_enc.enref_id.slice())
    existing_entity_loss *= 1.0 - is_new

    new_entity_loss = self._compute_hinge_losses(
        labels=target_enc.enref_meta.get_is_new_slice(),
        logits=predicted_enc.enref_meta.get_is_new_slice())
    new_entity_loss *= is_new
    # Up-weight missed new entities. (Renamed from the misleading `fp_cost`
    # for consistency with _compute_new_id_losses: the config knob is a
    # false-NEGATIVE cost.)
    fn_cost = self.config.new_id_false_negative_cost - 1.0
    new_entity_loss *= tf.ones_like(new_entity_loss) + fn_cost * is_new

    loss += existing_entity_loss + new_entity_loss

    # Is_entity loss: the per-entity losses above only count where the
    # target actually is an enref; the meta loss always applies.
    not_an_entity_loss = self._compute_hinge_losses(
        labels=target_enc.enref_meta.slice(),
        logits=predicted_enc.enref_meta.slice())
    loss *= target_enc.enref_meta.is_enref()
    loss += not_an_entity_loss

    return loss

  def _compute_new_id_losses(self, target, predicted):
    """Computes the new id losses for each token in each turn."""
    target_meta = self.encodings.as_prediction_encoding(target).enref_meta
    predicted_meta = self.encodings.as_prediction_encoding(predicted).enref_meta

    losses = self._compute_hinge_losses(
        labels=target_meta.get_is_new_slice(),
        logits=predicted_meta.get_is_new_slice())
    # Penalize false negatives on new entities more strongly.
    fn_cost = self.config.new_id_false_negative_cost - 1.0
    new_entity_positives = target_meta.is_new() + target_meta.is_new_continued()
    losses *= tf.ones_like(losses) + fn_cost * new_entity_positives
    return losses

  def call(self, target, predicted):
    """Compute loss comparing predicted and actual annotations."""
    with tf.name_scope('contrack_loss'):
      if self.mode == 'only_new_entities':
        return self._compute_new_id_losses(target, predicted)
      else:
        return self._compute_annotation_losses(target, predicted)
912
def _get_named_slices(y_true, logits, section_name):
  """Selects the (true, predicted) slice pair named by `section_name`.

  Args:
    y_true: prediction encoding of the ground truth.
    logits: prediction encoding of the model output.
    section_name: one of 'new_entity', 'entities', 'properties',
      'membership'.

  Returns:
    A (true_slice, predicted_slice) tuple.

  Raises:
    ValueError: if section_name is not one of the known sections.
  """
  # Predicted slices are gated on the true is-enref indicator, so positions
  # that are not entities contribute zeros.
  entity_gate = tf.expand_dims(y_true.enref_meta.is_enref(), 2)
  if section_name == 'new_entity':
    true_slice = y_true.enref_meta.get_is_new_slice()
    pred_slice = entity_gate * logits.enref_meta.get_is_new_slice()
  elif section_name == 'entities':
    true_slice = y_true.enref_id.slice()
    pred_slice = entity_gate * logits.enref_id.slice()
  elif section_name == 'properties':
    true_slice = y_true.enref_properties.slice()
    pred_slice = entity_gate * logits.enref_properties.slice()
  elif section_name == 'membership':
    # Membership is additionally gated on the target being a group.
    group_gate = tf.expand_dims(y_true.enref_properties.is_group(), 2)
    true_slice = y_true.enref_membership.slice()
    pred_slice = entity_gate * group_gate * logits.enref_membership.slice()
  else:
    raise ValueError('Unknown section name %s' % section_name)
  return (true_slice, pred_slice)
932
class ContrackAccuracy(tf.keras.metrics.Mean):
  """Computes zero-one accuracy on a given slice of the result vector."""

  def __init__(self, section_name, dtype=None):
    # Attributes are set before super().__init__ so they exist when Keras
    # initializes the metric.
    self.encodings = Env.get().encodings
    self.section_name = section_name
    super(ContrackAccuracy, self).__init__(
        name=f'{section_name}/accuracy', dtype=dtype)

  @classmethod
  def from_config(cls, config):
    return ContrackAccuracy(config['section_name'])

  def get_config(self):
    return {'section_name': self.section_name}

  def update_state(self, y_true, logits, sample_weight=None):
    """Accumulates per-position match scores for the configured section."""
    y_true, logits = _get_named_slices(
        self.encodings.as_prediction_encoding(y_true),
        self.encodings.as_prediction_encoding(logits), self.section_name)
    # Binarize logits at threshold 0.
    y_pred = tf.cast(logits > 0.0, tf.float32)

    # NOTE(review): reduce_max scores a position as correct if ANY component
    # of the slice matches; exact-match ("zero-one") accuracy would use
    # reduce_min — confirm which behavior is intended.
    matches = tf.reduce_max(tf.cast(y_true == y_pred, tf.float32), -1)

    super(ContrackAccuracy, self).update_state(matches, sample_weight)
962
class ContrackPrecision(tf.keras.metrics.Precision):
  """Computes precision on a given slice of the result vector."""

  def __init__(self, section_name, dtype=None):
    # Set before super().__init__ so the attributes exist during Keras setup.
    self.encodings = Env.get().encodings
    self.section_name = section_name
    super(ContrackPrecision, self).__init__(
        name=f'{section_name}/precision', dtype=dtype)

  @classmethod
  def from_config(cls, config):
    return ContrackPrecision(config['section_name'])

  def get_config(self):
    return {'section_name': self.section_name}

  def update_state(self, y_true, logits, sample_weight=None):
    """Updates precision over the configured slice, thresholding logits at 0."""
    true_enc = self.encodings.as_prediction_encoding(y_true)
    pred_enc = self.encodings.as_prediction_encoding(logits)
    true_slice, logit_slice = _get_named_slices(true_enc, pred_enc,
                                                self.section_name)
    predictions = tf.cast(logit_slice > 0.0, tf.float32)

    super(ContrackPrecision, self).update_state(true_slice, predictions,
                                                sample_weight)
990
class ContrackRecall(tf.keras.metrics.Recall):
  """Computes recall on a given slice of the result vector."""

  def __init__(self, section_name, dtype=None):
    # Set before super().__init__ so the attributes exist during Keras setup.
    self.encodings = Env.get().encodings
    self.section_name = section_name
    super(ContrackRecall, self).__init__(
        name=f'{section_name}/recall', dtype=dtype)

  @classmethod
  def from_config(cls, config):
    return ContrackRecall(config['section_name'])

  def get_config(self):
    return {'section_name': self.section_name}

  def update_state(self, y_true, logits, sample_weight=None):
    """Updates recall over the configured slice, thresholding logits at 0."""
    true_enc = self.encodings.as_prediction_encoding(y_true)
    pred_enc = self.encodings.as_prediction_encoding(logits)
    true_slice, logit_slice = _get_named_slices(true_enc, pred_enc,
                                                self.section_name)
    predictions = tf.cast(logit_slice > 0.0, tf.float32)

    super(ContrackRecall, self).update_state(true_slice, predictions,
                                             sample_weight)
1018
class ContrackF1Score(tf.keras.metrics.Metric):
  """Computes the f1 score on a given slice of the result vector."""

  def __init__(self, section_name, dtype=None):
    self.encodings = Env.get().encodings
    self.section_name = section_name
    # F1 is derived from separately tracked precision and recall metrics.
    self.precision = ContrackPrecision(section_name, dtype=dtype)
    self.recall = ContrackRecall(section_name, dtype=dtype)
    super(ContrackF1Score, self).__init__(
        name=f'{section_name}/f1score', dtype=dtype)

  @classmethod
  def from_config(cls, config):
    return ContrackF1Score(config['section_name'])

  def get_config(self):
    return {'section_name': self.section_name}

  def add_weight(self, **kwargs):
    # Delegate variable creation to the wrapped metrics; this metric holds
    # no state of its own.
    self.precision.add_weight(**kwargs)
    self.recall.add_weight(**kwargs)

  def reset_states(self):
    # NOTE(review): reset_states is the legacy Keras name (newer Keras uses
    # reset_state) — keep in sync with the Keras version the repo pins.
    self.precision.reset_states()
    self.recall.reset_states()

  def result(self):
    precision = self.precision.result()
    recall = self.recall.result()
    # Epsilon guards against division by zero when both counts are zero.
    return 2.0 * (precision * recall) / (precision + recall +
                                         tf.keras.backend.epsilon())

  def update_state(self, y_true, logits, sample_weight=None):
    self.precision.update_state(y_true, logits, sample_weight)
    self.recall.update_state(y_true, logits, sample_weight)
1058
def build_metrics(mode):
  """Creates list of metrics for all metric types and sections."""
  # In new-entity-only mode there is nothing to score beyond that section.
  if mode == 'only_new_entities':
    sections = ['new_entity']
  else:
    sections = ['new_entity', 'entities', 'properties', 'membership']

  metric_types = [
      ContrackAccuracy, ContrackPrecision, ContrackRecall, ContrackF1Score
  ]
  # One metric of every type per section, section-major order.
  return [
      metric_type(section)
      for section in sections
      for metric_type in metric_types
  ]
1076
def get_custom_objects():
  """Returns the custom classes Keras needs to (de)serialize Contrack models.

  Pass to tf.keras.models.load_model(..., custom_objects=get_custom_objects()).
  """
  return {
      'ContrackModel': ContrackModel,
      'ConvertToSequenceLayer': ConvertToSequenceLayer,
      'IdentifyNewEntityLayer': IdentifyNewEntityLayer,
      'ComputeIdsLayer': ComputeIdsLayer,
      'TrackEnrefsLayer': TrackEnrefsLayer,
      # Previously missing: without it, deserializing a full tracking model
      # (whose graph contains a MergeIdsLayer) fails.
      'MergeIdsLayer': MergeIdsLayer,
      'ContrackLoss': ContrackLoss,
      'ContrackAccuracy': ContrackAccuracy,
      'ContrackPrecision': ContrackPrecision,
      'ContrackRecall': ContrackRecall,
      'ContrackF1Score': ContrackF1Score,
  }