# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16"""Baseline model for Schema-guided Dialogue State Tracking.
17
18Adapted from
19https://github.com/google-research/bert/blob/master/run_classifier.py
20"""
21
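# Example training invocation (illustrative only: the module path, directory
# values and task name below are placeholders, not defined in this file):
#
#   python -m schema_guided_dst.baseline.train_and_predict \
#     --bert_ckpt_dir /path/to/bert_ckpt \
#     --dstc8_data_dir /path/to/dstc8_data \
#     --dialogues_example_dir /tmp/dial_examples \
#     --schema_embedding_dir /tmp/schema_emb \
#     --output_dir /tmp/model_output \
#     --task_name dstc8_single_domain \
#     --run_mode train --dataset_split train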
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator

from schema_guided_dst import schema
from schema_guided_dst.baseline import config
from schema_guided_dst.baseline import data_utils
from schema_guided_dst.baseline import extract_schema_embedding
from schema_guided_dst.baseline import pred_utils
from schema_guided_dst.baseline.bert import modeling
from schema_guided_dst.baseline.bert import optimization
from schema_guided_dst.baseline.bert import tokenization

flags = tf.compat.v1.flags
FLAGS = flags.FLAGS
# BERT based utterance encoder related flags.
flags.DEFINE_string("bert_ckpt_dir", None,
                    "Directory containing pre-trained BERT checkpoint.")

flags.DEFINE_bool(
    "do_lower_case", False,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 80,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_float("dropout_rate", 0.1,
                   "Dropout rate for BERT representations.")

# Hyperparameters and optimization related flags.
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")

flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")

flags.DEFINE_float("learning_rate", 1e-4, "The initial learning rate for Adam.")

flags.DEFINE_float("num_train_epochs", 80.0,
                   "Total number of training epochs to perform.")
flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer("save_checkpoints_steps", 1000,
                     "How often to save the model checkpoint.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")
flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located. If not specified, "
    "we will attempt to automatically detect the zone from metadata.")
flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")

flags.DEFINE_bool(
    "use_one_hot_embeddings", False,
    "If True, tf.one_hot will be used for embedding lookups, otherwise "
    "tf.nn.embedding_lookup will be used. On TPUs, this should be True "
    "since it is much faster.")

# Input and output paths and other flags.
flags.DEFINE_enum("task_name", None, config.DATASET_CONFIG.keys(),
                  "The name of the task to train.")
flags.DEFINE_string(
    "dstc8_data_dir", None,
    "Directory for the downloaded DSTC8 data, which contains the dialogue "
    "files and schema files of all datasets (e.g., train, dev).")

flags.DEFINE_enum("run_mode", None, ["train", "predict"],
                  "The mode to run the script in.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "schema_embedding_dir", None,
    "Directory where the .npy file containing embeddings of schema entities "
    "(slots, values, intents) for the dataset_split is stored.")

flags.DEFINE_string(
    "dialogues_example_dir", None,
    "Directory where the TFRecord files of DSTC8 dialogue examples are "
    "stored.")

flags.DEFINE_enum("dataset_split", None, ["train", "dev", "test"],
                  "Dataset split for training / prediction.")

flags.DEFINE_string(
    "eval_ckpt", "",
    "Comma-separated checkpoint step numbers; a prediction run is made with "
    "the model checkpoint corresponding to each of them.")

flags.DEFINE_bool(
    "overwrite_dial_file", False,
    "Whether to generate a new TFRecord file saving the dialogue examples.")

flags.DEFINE_bool(
    "overwrite_schema_emb_file", False,
    "Whether to generate a new schema_emb file saving the schemas' "
    "embeddings.")

flags.DEFINE_bool(
    "log_data_warnings", False,
    "If True, warnings created during data processing are logged.")
# Modified from run_classifier.file_based_input_fn_builder
def _file_based_input_fn_builder(dataset_config, input_dial_file,
                                 schema_embedding_file, is_training,
                                 drop_remainder):
  """Creates an `input_fn` closure to be passed to TPUEstimator."""

  max_num_cat_slot = dataset_config.max_num_cat_slot
  max_num_noncat_slot = dataset_config.max_num_noncat_slot
  max_num_total_slot = max_num_cat_slot + max_num_noncat_slot
  max_num_intent = dataset_config.max_num_intent
  max_utt_len = FLAGS.max_seq_length

  name_to_features = {
      "example_id": tf.io.FixedLenFeature([], tf.string),
      "is_real_example": tf.io.FixedLenFeature([], tf.int64),
      "service_id": tf.io.FixedLenFeature([], tf.int64),
      "utt": tf.io.FixedLenFeature([max_utt_len], tf.int64),
      "utt_mask": tf.io.FixedLenFeature([max_utt_len], tf.int64),
      "utt_seg": tf.io.FixedLenFeature([max_utt_len], tf.int64),
      "cat_slot_num": tf.io.FixedLenFeature([], tf.int64),
      "cat_slot_status": tf.io.FixedLenFeature([max_num_cat_slot], tf.int64),
      "cat_slot_value_num": tf.io.FixedLenFeature([max_num_cat_slot],
                                                  tf.int64),
      "cat_slot_value": tf.io.FixedLenFeature([max_num_cat_slot], tf.int64),
      "noncat_slot_num": tf.io.FixedLenFeature([], tf.int64),
      "noncat_slot_status": tf.io.FixedLenFeature([max_num_noncat_slot],
                                                  tf.int64),
      "noncat_slot_value_start": tf.io.FixedLenFeature([max_num_noncat_slot],
                                                       tf.int64),
      "noncat_slot_value_end": tf.io.FixedLenFeature([max_num_noncat_slot],
                                                     tf.int64),
      "noncat_alignment_start": tf.io.FixedLenFeature([max_utt_len], tf.int64),
      "noncat_alignment_end": tf.io.FixedLenFeature([max_utt_len], tf.int64),
      "req_slot_num": tf.io.FixedLenFeature([], tf.int64),
      "req_slot_status": tf.io.FixedLenFeature([max_num_total_slot], tf.int64),
      "intent_num": tf.io.FixedLenFeature([], tf.int64),
      "intent_status": tf.io.FixedLenFeature([max_num_intent], tf.int64),
  }
  with tf.io.gfile.GFile(schema_embedding_file, "rb") as f:
    schema_data = np.load(f, allow_pickle=True)

  # Convert from list of dict to dict of list.
  schema_data_dict = collections.defaultdict(list)
  for service in schema_data:
    schema_data_dict["cat_slot_emb"].append(service["cat_slot_emb"])
    schema_data_dict["cat_slot_value_emb"].append(service["cat_slot_value_emb"])
    schema_data_dict["noncat_slot_emb"].append(service["noncat_slot_emb"])
    schema_data_dict["req_slot_emb"].append(service["req_slot_emb"])
    schema_data_dict["intent_emb"].append(service["intent_emb"])

  def _decode_record(record, name_to_features, schema_tensors):
    """Decodes a record to a TensorFlow example."""
    example = tf.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
    # so cast all int64 tensors to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    # Here we need to insert the schema's entity embeddings into each example.

    # Shapes for reference: (all have type tf.float32)
    # "cat_slot_emb": [max_num_cat_slot, hidden_dim]
    # "cat_slot_value_emb": [max_num_cat_slot, max_num_value, hidden_dim]
    # "noncat_slot_emb": [max_num_noncat_slot, hidden_dim]
    # "req_slot_emb": [max_num_total_slot, hidden_dim]
    # "intent_emb": [max_num_intent, hidden_dim]
    service_id = example["service_id"]
    for key, value in schema_tensors.items():
      example[key] = value[service_id]
    return example

  def input_fn(params):
    """The actual input function."""
    batch_size = params["batch_size"]

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    d = tf.data.TFRecordDataset(input_dial_file)
    # Uncomment for debugging
    # d = d.take(12)
    if is_training:
      d = d.repeat()
      d = d.shuffle(buffer_size=100)
    schema_tensors = {}
    for key, array in schema_data_dict.items():
      schema_tensors[key] = tf.convert_to_tensor(np.asarray(array, np.float32))

    d = d.apply(
        tf.data.experimental.map_and_batch(
            lambda rec: _decode_record(rec, name_to_features, schema_tensors),
            batch_size=batch_size,
            drop_remainder=drop_remainder))
    return d

  return input_fn

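# A minimal sketch of how the closure above is consumed (TPUEstimator supplies
# `params`; "batch_size" is the only key read here):
#
#   input_fn = _file_based_input_fn_builder(dataset_config, dial_file,
#                                           schema_emb_file, is_training=True,
#                                           drop_remainder=True)
#   dataset = input_fn({"batch_size": 32})
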
class SchemaGuidedDST(object):
  """Baseline model for schema-guided dialogue state tracking."""

  def __init__(self, bert_config, use_one_hot_embeddings):
    self._bert_config = bert_config
    self._use_one_hot_embeddings = use_one_hot_embeddings

  def define_model(self, features, is_training):
    """Define the model computation.

    Args:
      features: A dict mapping feature names to corresponding tensors.
      is_training: A boolean which is True when the model is being trained.

    Returns:
      outputs: A dict mapping output names to corresponding tensors.
    """
    # Encode the utterances using BERT.
    self._encoded_utterance, self._encoded_tokens = (
        self._encode_utterances(features, is_training))
    outputs = {}
    outputs["logit_intent_status"] = self._get_intents(features)
    outputs["logit_req_slot_status"] = self._get_requested_slots(features)
    cat_slot_status, cat_slot_value = self._get_categorical_slot_goals(features)
    outputs["logit_cat_slot_status"] = cat_slot_status
    outputs["logit_cat_slot_value"] = cat_slot_value
    noncat_slot_status, noncat_span_start, noncat_span_end = (
        self._get_noncategorical_slot_goals(features))
    outputs["logit_noncat_slot_status"] = noncat_slot_status
    outputs["logit_noncat_slot_start"] = noncat_span_start
    outputs["logit_noncat_slot_end"] = noncat_span_end
    return outputs
  def define_loss(self, features, outputs):
    """Obtain the loss of the model."""
    # Intents.
    # Shape: (batch_size, max_num_intents + 1).
    intent_logits = outputs["logit_intent_status"]
    # Shape: (batch_size, max_num_intents).
    intent_labels = features["intent_status"]
    # Add label corresponding to the NONE intent.
    num_active_intents = tf.expand_dims(
        tf.reduce_sum(intent_labels, axis=1), axis=1)
    none_intent_label = tf.ones_like(num_active_intents) - num_active_intents
    # Shape: (batch_size, max_num_intents + 1).
    onehot_intent_labels = tf.concat([none_intent_label, intent_labels], axis=1)
    intent_loss = tf.losses.softmax_cross_entropy(
        onehot_intent_labels,
        intent_logits,
        weights=features["is_real_example"])

    # Requested slots.
    # Shape: (batch_size, max_num_slots).
    requested_slot_logits = outputs["logit_req_slot_status"]
    requested_slot_labels = features["req_slot_status"]
    max_num_requested_slots = requested_slot_labels.get_shape().as_list()[-1]
    weights = tf.sequence_mask(
        features["req_slot_num"], maxlen=max_num_requested_slots)
    # Sigmoid cross entropy is used because more than one slot can be requested
    # in a single utterance.
    requested_slot_loss = tf.losses.sigmoid_cross_entropy(
        requested_slot_labels, requested_slot_logits, weights=weights)

    # Categorical slot status.
    # Shape: (batch_size, max_num_cat_slots, 3).
    cat_slot_status_logits = outputs["logit_cat_slot_status"]
    cat_slot_status_labels = features["cat_slot_status"]
    max_num_cat_slots = cat_slot_status_labels.get_shape().as_list()[-1]
    one_hot_labels = tf.one_hot(cat_slot_status_labels, 3, dtype=tf.int32)
    cat_weights = tf.sequence_mask(
        features["cat_slot_num"], maxlen=max_num_cat_slots, dtype=tf.float32)
    cat_slot_status_loss = tf.losses.softmax_cross_entropy(
        tf.reshape(one_hot_labels, [-1, 3]),
        tf.reshape(cat_slot_status_logits, [-1, 3]),
        weights=tf.reshape(cat_weights, [-1]))

    # Categorical slot values.
    # Shape: (batch_size, max_num_cat_slots, max_num_slot_values).
    cat_slot_value_logits = outputs["logit_cat_slot_value"]
    cat_slot_value_labels = features["cat_slot_value"]
    max_num_slot_values = cat_slot_value_logits.get_shape().as_list()[-1]
    one_hot_labels = tf.one_hot(
        cat_slot_value_labels, max_num_slot_values, dtype=tf.int32)
    # Zero out losses for categorical slot values when the slot status is not
    # active.
    cat_loss_weight = tf.cast(
        tf.equal(cat_slot_status_labels, data_utils.STATUS_ACTIVE), tf.float32)
    cat_slot_value_loss = tf.losses.softmax_cross_entropy(
        tf.reshape(one_hot_labels, [-1, max_num_slot_values]),
        tf.reshape(cat_slot_value_logits, [-1, max_num_slot_values]),
        weights=tf.reshape(cat_weights * cat_loss_weight, [-1]))

    # Non-categorical slot status.
    # Shape: (batch_size, max_num_noncat_slots, 3).
    noncat_slot_status_logits = outputs["logit_noncat_slot_status"]
    noncat_slot_status_labels = features["noncat_slot_status"]
    max_num_noncat_slots = noncat_slot_status_labels.get_shape().as_list()[-1]
    one_hot_labels = tf.one_hot(noncat_slot_status_labels, 3, dtype=tf.int32)
    noncat_weights = tf.sequence_mask(
        features["noncat_slot_num"],
        maxlen=max_num_noncat_slots,
        dtype=tf.float32)
    # Logits for padded (invalid) values are already masked.
    noncat_slot_status_loss = tf.losses.softmax_cross_entropy(
        tf.reshape(one_hot_labels, [-1, 3]),
        tf.reshape(noncat_slot_status_logits, [-1, 3]),
        weights=tf.reshape(noncat_weights, [-1]))

    # Non-categorical slot spans.
    # Shape: (batch_size, max_num_noncat_slots, max_num_tokens).
    span_start_logits = outputs["logit_noncat_slot_start"]
    span_start_labels = features["noncat_slot_value_start"]
    max_num_tokens = span_start_logits.get_shape().as_list()[-1]
    onehot_start_labels = tf.one_hot(
        span_start_labels, max_num_tokens, dtype=tf.int32)
    # Shape: (batch_size, max_num_noncat_slots, max_num_tokens).
    span_end_logits = outputs["logit_noncat_slot_end"]
    span_end_labels = features["noncat_slot_value_end"]
    onehot_end_labels = tf.one_hot(
        span_end_labels, max_num_tokens, dtype=tf.int32)
    # Zero out losses for non-categorical slot spans when the slot status is
    # not active.
    noncat_loss_weight = tf.cast(
        tf.equal(noncat_slot_status_labels, data_utils.STATUS_ACTIVE),
        tf.float32)
    span_start_loss = tf.losses.softmax_cross_entropy(
        tf.reshape(onehot_start_labels, [-1, max_num_tokens]),
        tf.reshape(span_start_logits, [-1, max_num_tokens]),
        weights=tf.reshape(noncat_weights * noncat_loss_weight, [-1]))
    span_end_loss = tf.losses.softmax_cross_entropy(
        tf.reshape(onehot_end_labels, [-1, max_num_tokens]),
        tf.reshape(span_end_logits, [-1, max_num_tokens]),
        weights=tf.reshape(noncat_weights * noncat_loss_weight, [-1]))

    losses = {
        "intent_loss": intent_loss,
        "requested_slot_loss": requested_slot_loss,
        "cat_slot_status_loss": cat_slot_status_loss,
        "cat_slot_value_loss": cat_slot_value_loss,
        "noncat_slot_status_loss": noncat_slot_status_loss,
        "span_start_loss": span_start_loss,
        "span_end_loss": span_end_loss,
    }
    for loss_name, loss in losses.items():
      tf.summary.scalar(loss_name, loss)
    # The total loss is the unweighted average of the per-head losses.
    return sum(losses.values()) / len(losses)
  def define_predictions(self, features, outputs):
    """Define model predictions."""
    predictions = {
        "example_id": features["example_id"],
        "service_id": features["service_id"],
        "is_real_example": features["is_real_example"],
    }
    # Scores are output for each intent.
    # Note that the intent indices are shifted by 1 to account for NONE intent.
    predictions["intent_status"] = tf.argmax(
        outputs["logit_intent_status"], axis=-1)

    # Scores are output for each requested slot.
    predictions["req_slot_status"] = tf.sigmoid(
        outputs["logit_req_slot_status"])

    # For categorical slots, the status of each slot and the predicted value
    # are output.
    predictions["cat_slot_status"] = tf.argmax(
        outputs["logit_cat_slot_status"], axis=-1)
    predictions["cat_slot_value"] = tf.argmax(
        outputs["logit_cat_slot_value"], axis=-1)

    # For non-categorical slots, the status of each slot and the indices for
    # spans are output.
    predictions["noncat_slot_status"] = tf.argmax(
        outputs["logit_noncat_slot_status"], axis=-1)
    start_scores = tf.nn.softmax(outputs["logit_noncat_slot_start"], axis=-1)
    end_scores = tf.nn.softmax(outputs["logit_noncat_slot_end"], axis=-1)
    _, max_num_slots, max_num_tokens = end_scores.get_shape().as_list()
    batch_size = tf.shape(end_scores)[0]
    # Find the span with the maximum sum of scores for start and end indices.
    total_scores = (
        tf.expand_dims(start_scores, axis=3) +
        tf.expand_dims(end_scores, axis=2))
    # Mask out scores where start_index > end_index.
    start_idx = tf.reshape(tf.range(max_num_tokens), [1, 1, -1, 1])
    end_idx = tf.reshape(tf.range(max_num_tokens), [1, 1, 1, -1])
    invalid_index_mask = tf.tile((start_idx > end_idx),
                                 [batch_size, max_num_slots, 1, 1])
    total_scores = tf.where(invalid_index_mask, tf.zeros_like(total_scores),
                            total_scores)
    max_span_index = tf.argmax(
        tf.reshape(total_scores, [-1, max_num_slots, max_num_tokens**2]),
        axis=-1)
    span_start_index = tf.floordiv(max_span_index, max_num_tokens)
    span_end_index = tf.floormod(max_span_index, max_num_tokens)
    predictions["noncat_slot_start"] = span_start_index
    predictions["noncat_slot_end"] = span_end_index
    # Add inverse alignments.
    predictions["noncat_alignment_start"] = features["noncat_alignment_start"]
    predictions["noncat_alignment_end"] = features["noncat_alignment_end"]
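    # Illustrative check of the flat-index decode above (made-up numbers):
    # with max_num_tokens = 4, a flat argmax index of 6 decodes to
    # span_start_index = 6 // 4 = 1 and span_end_index = 6 % 4 = 2, i.e. the
    # span from token 1 to token 2 inclusive.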
    return predictions

  def _encode_utterances(self, features, is_training):
    """Encode system and user utterances using BERT."""
    # Obtain the embedded representation of system and user utterances in the
    # turn and the corresponding token level representations.
    bert_encoder = modeling.BertModel(
        config=self._bert_config,
        is_training=is_training,
        input_ids=features["utt"],
        input_mask=features["utt_mask"],
        token_type_ids=features["utt_seg"],
        use_one_hot_embeddings=self._use_one_hot_embeddings)
    encoded_utterance = bert_encoder.get_pooled_output()
    encoded_tokens = bert_encoder.get_sequence_output()

    # Apply dropout in training mode.
    encoded_utterance = tf.layers.dropout(
        encoded_utterance, rate=FLAGS.dropout_rate, training=is_training)
    encoded_tokens = tf.layers.dropout(
        encoded_tokens, rate=FLAGS.dropout_rate, training=is_training)
    return encoded_utterance, encoded_tokens
  def _get_logits(self, element_embeddings, num_classes, name_scope):
    """Get logits for elements by conditioning on utterance embedding.

    Args:
      element_embeddings: A tensor of shape (batch_size, num_elements,
        embedding_dim).
      num_classes: An int containing the number of classes for which logits are
        to be generated.
      name_scope: The name scope to be used for layers.

    Returns:
      A tensor of shape (batch_size, num_elements, num_classes) containing the
      logits.
    """
    _, num_elements, embedding_dim = element_embeddings.get_shape().as_list()
    # Project the utterance embeddings.
    utterance_proj = tf.keras.layers.Dense(
        units=embedding_dim,
        activation=modeling.gelu,
        name="{}_utterance_proj".format(name_scope))
    utterance_embedding = utterance_proj(self._encoded_utterance)
    # Combine the utterance and element embeddings.
    repeat_utterance_embeddings = tf.tile(
        tf.expand_dims(utterance_embedding, axis=1), [1, num_elements, 1])
    utterance_element_emb = tf.concat(
        [repeat_utterance_embeddings, element_embeddings], axis=2)
    # Project the combined embeddings to obtain logits.
    layer_1 = tf.keras.layers.Dense(
        units=embedding_dim,
        activation=modeling.gelu,
        name="{}_projection_1".format(name_scope))
    layer_2 = tf.keras.layers.Dense(
        units=num_classes, name="{}_projection_2".format(name_scope))
    return layer_2(layer_1(utterance_element_emb))
  def _get_intents(self, features):
    """Obtain logits for intents."""
    intent_embeddings = features["intent_emb"]
    # Add a trainable vector for the NONE intent.
    _, max_num_intents, embedding_dim = intent_embeddings.get_shape().as_list()
    null_intent_embedding = tf.get_variable(
        "null_intent_embedding",
        shape=[1, 1, embedding_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    batch_size = tf.shape(intent_embeddings)[0]
    repeated_null_intent_embedding = tf.tile(null_intent_embedding,
                                             [batch_size, 1, 1])
    intent_embeddings = tf.concat(
        [repeated_null_intent_embedding, intent_embeddings], axis=1)

    logits = self._get_logits(intent_embeddings, 1, "intents")
    # Shape: (batch_size, max_num_intents + 1).
    logits = tf.squeeze(logits, axis=-1)
    # Mask out logits for padded intents. 1 is added to account for the NONE
    # intent.
    mask = tf.sequence_mask(
        features["intent_num"] + 1, maxlen=max_num_intents + 1)
    negative_logits = -0.7 * tf.ones_like(logits) * logits.dtype.max
    return tf.where(mask, logits, negative_logits)

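  # Note on the masking constant above (also used in the methods below):
  # padded positions get the large finite value -0.7 * dtype.max instead of
  # -inf, presumably to keep downstream softmax/argmax ops free of inf/NaN
  # arithmetic; after softmax these positions receive probability ~0.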
  def _get_requested_slots(self, features):
    """Obtain logits for requested slots."""
    slot_embeddings = features["req_slot_emb"]
    logits = self._get_logits(slot_embeddings, 1, "requested_slots")
    return tf.squeeze(logits, axis=-1)
  def _get_categorical_slot_goals(self, features):
    """Obtain logits for status and values for categorical slots."""
    # Predict the status of all categorical slots.
    slot_embeddings = features["cat_slot_emb"]
    status_logits = self._get_logits(slot_embeddings, 3,
                                     "categorical_slot_status")

    # Predict the goal value.

    # Shape: (batch_size, max_categorical_slots, max_categorical_values,
    # embedding_dim).
    value_embeddings = features["cat_slot_value_emb"]
    _, max_num_slots, max_num_values, embedding_dim = (
        value_embeddings.get_shape().as_list())
    value_embeddings_reshaped = tf.reshape(
        value_embeddings, [-1, max_num_slots * max_num_values, embedding_dim])
    value_logits = self._get_logits(value_embeddings_reshaped, 1,
                                    "categorical_slot_values")
    # Reshape to obtain the logits for all slots.
    value_logits = tf.reshape(value_logits, [-1, max_num_slots, max_num_values])
    # Mask out logits for padded slots and values because they will be
    # softmaxed.
    mask = tf.sequence_mask(
        features["cat_slot_value_num"], maxlen=max_num_values)
    negative_logits = -0.7 * tf.ones_like(value_logits) * value_logits.dtype.max
    value_logits = tf.where(mask, value_logits, negative_logits)
    return status_logits, value_logits
  def _get_noncategorical_slot_goals(self, features):
    """Obtain logits for status and slot spans for non-categorical slots."""
    # Predict the status of all non-categorical slots.
    slot_embeddings = features["noncat_slot_emb"]
    max_num_slots = slot_embeddings.get_shape().as_list()[1]
    status_logits = self._get_logits(slot_embeddings, 3,
                                     "noncategorical_slot_status")

    # Predict the distribution for span indices.
    token_embeddings = self._encoded_tokens
    max_num_tokens = token_embeddings.get_shape().as_list()[1]
    tiled_token_embeddings = tf.tile(
        tf.expand_dims(token_embeddings, 1), [1, max_num_slots, 1, 1])
    tiled_slot_embeddings = tf.tile(
        tf.expand_dims(slot_embeddings, 2), [1, 1, max_num_tokens, 1])
    # Shape: (batch_size, max_num_slots, max_num_tokens, 2 * embedding_dim).
    slot_token_embeddings = tf.concat(
        [tiled_slot_embeddings, tiled_token_embeddings], axis=3)

    # Project the combined embeddings to obtain logits.
    embedding_dim = slot_embeddings.get_shape().as_list()[-1]
    layer_1 = tf.keras.layers.Dense(
        units=embedding_dim,
        activation=modeling.gelu,
        name="noncat_spans_layer_1")
    layer_2 = tf.keras.layers.Dense(units=2, name="noncat_spans_layer_2")
    # Shape: (batch_size, max_num_slots, max_num_tokens, 2).
    span_logits = layer_2(layer_1(slot_token_embeddings))

    # Mask out invalid logits for padded tokens.
    token_mask = features["utt_mask"]  # Shape: (batch_size, max_num_tokens).
    token_mask = tf.cast(token_mask, tf.bool)
    tiled_token_mask = tf.tile(
        tf.expand_dims(tf.expand_dims(token_mask, 1), 3),
        [1, max_num_slots, 1, 2])
    negative_logits = -0.7 * tf.ones_like(span_logits) * span_logits.dtype.max
    span_logits = tf.where(tiled_token_mask, span_logits, negative_logits)
    # Shape of both tensors: (batch_size, max_num_slots, max_num_tokens).
    span_start_logits, span_end_logits = tf.unstack(span_logits, axis=3)
    return status_logits, span_start_logits, span_end_logits

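# A minimal sketch (for orientation only) of how the class above is wired
# together inside `model_fn` below; `features` must contain the tensors
# produced by _file_based_input_fn_builder:
#
#   model = SchemaGuidedDST(bert_config, use_one_hot_embeddings=False)
#   outputs = model.define_model(features, is_training=True)
#   loss = model.define_loss(features, outputs)
#   predictions = model.define_predictions(features, outputs)
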
# Modified from run_classifier.model_fn_builder
def _model_fn_builder(bert_config, init_checkpoint, learning_rate,
                      num_train_steps, num_warmup_steps, use_tpu,
                      use_one_hot_embeddings):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    is_training = (mode == tf_estimator.ModeKeys.TRAIN)

    schema_guided_dst = SchemaGuidedDST(bert_config, use_one_hot_embeddings)
    outputs = schema_guided_dst.define_model(features, is_training)
    if is_training:
      total_loss = schema_guided_dst.define_loss(features, outputs)
    else:
      total_loss = tf.constant(0.0)

    tvars = tf.trainable_variables()
    scaffold_fn = None
    if init_checkpoint:
      assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
          tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    output_spec = None
    if mode == tf_estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(total_loss, learning_rate,
                                               num_train_steps,
                                               num_warmup_steps, use_tpu)
      global_step = tf.train.get_or_create_global_step()
      logged_tensors = {
          "global_step": global_step,
          "total_loss": total_loss,
      }
      output_spec = tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn,
          training_hooks=[
              tf.train.LoggingTensorHook(logged_tensors, every_n_iter=5)
          ])

    elif mode == tf_estimator.ModeKeys.EVAL:
      output_spec = tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode, loss=total_loss, scaffold_fn=scaffold_fn)

    else:  # mode == tf.estimator.ModeKeys.PREDICT
      predictions = schema_guided_dst.define_predictions(features, outputs)
      output_spec = tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)

    return output_spec

  return model_fn

def _create_dialog_examples(processor, dial_file):
  """Create dialog examples and save them in the given file."""
  if not tf.io.gfile.exists(FLAGS.dialogues_example_dir):
    tf.io.gfile.makedirs(FLAGS.dialogues_example_dir)
  frame_examples = processor.get_dialog_examples(FLAGS.dataset_split)
  data_utils.file_based_convert_examples_to_features(frame_examples,
                                                     processor.dataset_config,
                                                     dial_file)

def _create_schema_embeddings(bert_config, schema_embedding_file,
                              dataset_config):
  """Create schema embeddings and save them into a file."""
  if not tf.io.gfile.exists(FLAGS.schema_embedding_dir):
    tf.io.gfile.makedirs(FLAGS.schema_embedding_dir)
  is_per_host = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V2
  schema_emb_run_config = tf_estimator.tpu.RunConfig(
      master=FLAGS.master,
      tpu_config=tf_estimator.tpu.TPUConfig(
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  schema_json_path = os.path.join(FLAGS.dstc8_data_dir, FLAGS.dataset_split,
                                  "schema.json")
  schemas = schema.Schema(schema_json_path)

  # Prepare the BERT model for embedding natural language descriptions.
  bert_init_ckpt = os.path.join(FLAGS.bert_ckpt_dir, "bert_model.ckpt")
  schema_emb_model_fn = extract_schema_embedding.model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=bert_init_ckpt,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)
  # If TPU is not available, this falls back to a normal Estimator on CPU or
  # GPU.
  schema_emb_estimator = tf_estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=schema_emb_model_fn,
      config=schema_emb_run_config,
      predict_batch_size=FLAGS.predict_batch_size)
  vocab_file = os.path.join(FLAGS.bert_ckpt_dir, "vocab.txt")
  tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=FLAGS.do_lower_case)
  emb_generator = extract_schema_embedding.SchemaEmbeddingGenerator(
      tokenizer, schema_emb_estimator, FLAGS.max_seq_length)
  emb_generator.save_embeddings(schemas, schema_embedding_file, dataset_config)

def main(_):
  vocab_file = os.path.join(FLAGS.bert_ckpt_dir, "vocab.txt")
  task_name = FLAGS.task_name.lower()
  if task_name not in config.DATASET_CONFIG:
    raise ValueError("Task not found: %s" % task_name)
  dataset_config = config.DATASET_CONFIG[task_name]
  processor = data_utils.Dstc8DataProcessor(
      FLAGS.dstc8_data_dir,
      dataset_config=dataset_config,
      vocab_file=vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      max_seq_length=FLAGS.max_seq_length,
      log_data_warnings=FLAGS.log_data_warnings)
  # Generate the dialogue examples if needed or specified.
  dial_file_name = "{}_{}_examples.tf_record".format(task_name,
                                                     FLAGS.dataset_split)
  dial_file = os.path.join(FLAGS.dialogues_example_dir, dial_file_name)
  if not tf.io.gfile.exists(dial_file) or FLAGS.overwrite_dial_file:
    tf.compat.v1.logging.info("Start generating the dialogue examples.")
    _create_dialog_examples(processor, dial_file)
    tf.compat.v1.logging.info("Finish generating the dialogue examples.")
  # Generate the schema embeddings if needed or specified.
  bert_init_ckpt = os.path.join(FLAGS.bert_ckpt_dir, "bert_model.ckpt")
  tokenization.validate_case_matches_checkpoint(
      do_lower_case=FLAGS.do_lower_case, init_checkpoint=bert_init_ckpt)

  bert_config = modeling.BertConfig.from_json_file(
      os.path.join(FLAGS.bert_ckpt_dir, "bert_config.json"))
  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  schema_embedding_file = os.path.join(
      FLAGS.schema_embedding_dir,
      "{}_pretrained_schema_embedding.npy".format(FLAGS.dataset_split))
  if (not tf.io.gfile.exists(schema_embedding_file) or
      FLAGS.overwrite_schema_emb_file):
    tf.compat.v1.logging.info("Start generating the schema embeddings.")
    _create_schema_embeddings(bert_config, schema_embedding_file,
                              dataset_config)
    tf.compat.v1.logging.info("Finish generating the schema embeddings.")
  # Create estimator for training or inference.
  if not tf.io.gfile.exists(FLAGS.output_dir):
    tf.io.gfile.makedirs(FLAGS.output_dir)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf_estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=None,
      tpu_config=tf_estimator.tpu.TPUConfig(
          # Recommended value is number of global steps for next checkpoint.
          iterations_per_loop=FLAGS.save_checkpoints_steps,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.run_mode == "train":
    num_train_examples = processor.get_num_dialog_examples(FLAGS.dataset_split)
    num_train_steps = int(num_train_examples / FLAGS.train_batch_size *
                          FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
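  # Illustrative arithmetic for the step counts above (made-up numbers): with
  # 100,000 training examples, train_batch_size=32 and num_train_epochs=80,
  # num_train_steps = int(100000 / 32 * 80) = 250,000; with
  # warmup_proportion=0.1 this gives num_warmup_steps = 25,000.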
  bert_init_ckpt = os.path.join(FLAGS.bert_ckpt_dir, "bert_model.ckpt")
  model_fn = _model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=bert_init_ckpt,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      # One-hot embedding lookups are enabled whenever running on TPU, where
      # they are much faster (see the use_one_hot_embeddings flag above).
      use_one_hot_embeddings=FLAGS.use_tpu)

  # If TPU is not available, this falls back to a normal Estimator on CPU or
  # GPU.
  estimator = tf_estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)
  if FLAGS.run_mode == "train":
    # Train the model.
    tf.compat.v1.logging.info("***** Running training *****")
    tf.compat.v1.logging.info("  Num dial examples = %d", num_train_examples)
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
    train_input_fn = _file_based_input_fn_builder(
        dataset_config=dataset_config,
        input_dial_file=dial_file,
        schema_embedding_file=schema_embedding_file,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
  elif FLAGS.run_mode == "predict":
    # Run inference to obtain model predictions.
    num_actual_predict_examples = processor.get_num_dialog_examples(
        FLAGS.dataset_split)

    tf.compat.v1.logging.info("***** Running prediction *****")
    tf.compat.v1.logging.info("  Num actual examples = %d",
                              num_actual_predict_examples)
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = _file_based_input_fn_builder(
        dataset_config=dataset_config,
        input_dial_file=dial_file,
        schema_embedding_file=schema_embedding_file,
        is_training=False,
        drop_remainder=FLAGS.use_tpu)

    input_json_files = [
        os.path.join(FLAGS.dstc8_data_dir, FLAGS.dataset_split,
                     "dialogues_{:03d}.json".format(fid))
        for fid in dataset_config.file_ranges[FLAGS.dataset_split]
    ]
    schema_json_file = os.path.join(FLAGS.dstc8_data_dir, FLAGS.dataset_split,
                                    "schema.json")
    ckpt_nums = [num for num in FLAGS.eval_ckpt.split(",") if num]
    if not ckpt_nums:
      raise ValueError("No checkpoints assigned for prediction.")
    for ckpt_num in ckpt_nums:
      tf.compat.v1.logging.info("***** Predict results for %s set *****",
                                FLAGS.dataset_split)

      predictions = estimator.predict(
          input_fn=predict_input_fn,
          checkpoint_path=os.path.join(FLAGS.output_dir,
                                       "model.ckpt-%s" % ckpt_num))

      # Write predictions to file in DSTC8 format.
      dataset_mark = os.path.basename(FLAGS.dstc8_data_dir)
      prediction_dir = os.path.join(
          FLAGS.output_dir, "pred_res_{}_{}_{}_{}".format(
              int(ckpt_num), FLAGS.dataset_split, task_name, dataset_mark))
      if not tf.io.gfile.exists(prediction_dir):
        tf.io.gfile.makedirs(prediction_dir)
      pred_utils.write_predictions_to_file(predictions, input_json_files,
                                           schema_json_file, prediction_dir)

if __name__ == "__main__":
  flags.mark_flag_as_required("dstc8_data_dir")
  flags.mark_flag_as_required("bert_ckpt_dir")
  flags.mark_flag_as_required("dataset_split")
  flags.mark_flag_as_required("schema_embedding_dir")
  flags.mark_flag_as_required("dialogues_example_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("output_dir")
  tf.compat.v1.app.run(main)