# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Baseline model for Schema-guided Dialogue State Tracking.

Adapted from
https://github.com/google-research/bert/blob/master/run_classifier.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator


from schema_guided_dst import schema
from schema_guided_dst.baseline import config
from schema_guided_dst.baseline import data_utils
from schema_guided_dst.baseline import extract_schema_embedding
from schema_guided_dst.baseline import pred_utils
from schema_guided_dst.baseline.bert import modeling
from schema_guided_dst.baseline.bert import optimization
from schema_guided_dst.baseline.bert import tokenization


flags = tf.compat.v1.flags
FLAGS = flags.FLAGS

# BERT based utterance encoder related flags.
flags.DEFINE_string("bert_ckpt_dir", None,
                    "Directory containing pre-trained BERT checkpoint.")

flags.DEFINE_bool(
    "do_lower_case", False,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 80,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_float("dropout_rate", 0.1,
                   "Dropout rate for BERT representations.")

# Hyperparameters and optimization related flags.
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")

flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")

flags.DEFINE_float("learning_rate", 1e-4, "The initial learning rate for Adam.")

flags.DEFINE_float("num_train_epochs", 80.0,
                   "Total number of training epochs to perform.")
flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer("save_checkpoints_steps", 1000,
                     "How often to save the model checkpoint.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located. If not specified, "
    "we will attempt to automatically detect the zone from metadata.")

flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")

flags.DEFINE_bool(
    "use_one_hot_embeddings", False,
    "If True, tf.one_hot will be used for embedding lookups, otherwise "
    "tf.nn.embedding_lookup will be used. On TPUs, this should be True "
    "since it is much faster.")

# Input and output paths and other flags.
flags.DEFINE_enum("task_name", None, config.DATASET_CONFIG.keys(),
                  "The name of the task to train.")

flags.DEFINE_string(
    "dstc8_data_dir", None,
    "Directory for the downloaded DSTC8 data, which contains the dialogue "
    "files and schema files of all datasets (e.g. train, dev).")

flags.DEFINE_enum("run_mode", None, ["train", "predict"],
                  "The mode to run the script in.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "schema_embedding_dir", None,
    "Directory where the .npy file with embeddings of entities (slots, "
    "values, intents) in the dataset_split's schema is stored.")

flags.DEFINE_string(
    "dialogues_example_dir", None,
    "Directory where the TFRecord files of DSTC8 dialogue examples are "
    "stored.")

flags.DEFINE_enum("dataset_split", None, ["train", "dev", "test"],
                  "Dataset split for training / prediction.")

flags.DEFINE_string(
    "eval_ckpt", "",
    "Comma-separated numbers, each being the step number of a model "
    "checkpoint used to make predictions.")

flags.DEFINE_bool(
    "overwrite_dial_file", False,
    "Whether to generate a new TFRecord file saving the dialogue examples.")

flags.DEFINE_bool(
    "overwrite_schema_emb_file", False,
    "Whether to generate a new schema_emb file saving the schemas' embeddings.")

flags.DEFINE_bool(
    "log_data_warnings", False,
    "If True, warnings generated during data processing are logged.")


# Modified from run_classifier.file_based_input_fn_builder
def _file_based_input_fn_builder(dataset_config, input_dial_file,
                                 schema_embedding_file, is_training,
                                 drop_remainder):
  """Creates an `input_fn` closure to be passed to TPUEstimator."""

  max_num_cat_slot = dataset_config.max_num_cat_slot
  max_num_noncat_slot = dataset_config.max_num_noncat_slot
  max_num_total_slot = max_num_cat_slot + max_num_noncat_slot
  max_num_intent = dataset_config.max_num_intent
  max_utt_len = FLAGS.max_seq_length

  name_to_features = {
      "example_id":
          tf.io.FixedLenFeature([], tf.string),
      "is_real_example":
          tf.io.FixedLenFeature([], tf.int64),
      "service_id":
          tf.io.FixedLenFeature([], tf.int64),
      "utt":
          tf.io.FixedLenFeature([max_utt_len], tf.int64),
      "utt_mask":
          tf.io.FixedLenFeature([max_utt_len], tf.int64),
      "utt_seg":
          tf.io.FixedLenFeature([max_utt_len], tf.int64),
      "cat_slot_num":
          tf.io.FixedLenFeature([], tf.int64),
      "cat_slot_status":
          tf.io.FixedLenFeature([max_num_cat_slot], tf.int64),
      "cat_slot_value_num":
          tf.io.FixedLenFeature([max_num_cat_slot], tf.int64),
      "cat_slot_value":
          tf.io.FixedLenFeature([max_num_cat_slot], tf.int64),
      "noncat_slot_num":
          tf.io.FixedLenFeature([], tf.int64),
      "noncat_slot_status":
          tf.io.FixedLenFeature([max_num_noncat_slot], tf.int64),
      "noncat_slot_value_start":
          tf.io.FixedLenFeature([max_num_noncat_slot], tf.int64),
      "noncat_slot_value_end":
          tf.io.FixedLenFeature([max_num_noncat_slot], tf.int64),
      "noncat_alignment_start":
          tf.io.FixedLenFeature([max_utt_len], tf.int64),
      "noncat_alignment_end":
          tf.io.FixedLenFeature([max_utt_len], tf.int64),
      "req_slot_num":
          tf.io.FixedLenFeature([], tf.int64),
      "req_slot_status":
          tf.io.FixedLenFeature([max_num_total_slot], tf.int64),
      "intent_num":
          tf.io.FixedLenFeature([], tf.int64),
      "intent_status":
          tf.io.FixedLenFeature([max_num_intent], tf.int64),
  }
  with tf.io.gfile.GFile(schema_embedding_file, "rb") as f:
    schema_data = np.load(f, allow_pickle=True)

  # Convert from a list of dicts to a dict of lists.
  schema_data_dict = collections.defaultdict(list)
  for service in schema_data:
    schema_data_dict["cat_slot_emb"].append(service["cat_slot_emb"])
    schema_data_dict["cat_slot_value_emb"].append(service["cat_slot_value_emb"])
    schema_data_dict["noncat_slot_emb"].append(service["noncat_slot_emb"])
    schema_data_dict["req_slot_emb"].append(service["req_slot_emb"])
    schema_data_dict["intent_emb"].append(service["intent_emb"])

  def _decode_record(record, name_to_features, schema_tensors):
    """Decodes a record to a TensorFlow example."""

    example = tf.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    # Here we need to insert the schema's entity embeddings into each example.

    # Shapes for reference: (all have type tf.float32)
    # "cat_slot_emb": [max_num_cat_slot, hidden_dim]
    # "cat_slot_value_emb": [max_num_cat_slot, max_num_value, hidden_dim]
    # "noncat_slot_emb": [max_num_noncat_slot, hidden_dim]
    # "req_slot_emb": [max_num_total_slot, hidden_dim]
    # "intent_emb": [max_num_intent, hidden_dim]

    service_id = example["service_id"]
    for key, value in schema_tensors.items():
      example[key] = value[service_id]
    return example

  def input_fn(params):
    """The actual input function."""
    batch_size = params["batch_size"]

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    d = tf.data.TFRecordDataset(input_dial_file)
    # Uncomment for debugging
    # d = d.take(12)
    if is_training:
      d = d.repeat()
      d = d.shuffle(buffer_size=100)
    schema_tensors = {}
    for key, array in schema_data_dict.items():
      schema_tensors[key] = tf.convert_to_tensor(np.asarray(array, np.float32))

    d = d.apply(
        tf.data.experimental.map_and_batch(
            lambda rec: _decode_record(rec, name_to_features, schema_tensors),
            batch_size=batch_size,
            drop_remainder=drop_remainder))
    return d

  return input_fn


class SchemaGuidedDST(object):
  """Baseline model for schema guided dialogue state tracking."""

  def __init__(self, bert_config, use_one_hot_embeddings):
    self._bert_config = bert_config
    self._use_one_hot_embeddings = use_one_hot_embeddings

  def define_model(self, features, is_training):
    """Define the model computation.

    Args:
      features: A dict mapping feature names to corresponding tensors.
      is_training: A boolean which is True when the model is being trained.

    Returns:
      outputs: A dict mapping output names to corresponding tensors.
    """
    # Encode the utterances using BERT.
    self._encoded_utterance, self._encoded_tokens = (
        self._encode_utterances(features, is_training))
    outputs = {}
    outputs["logit_intent_status"] = self._get_intents(features)
    outputs["logit_req_slot_status"] = self._get_requested_slots(features)
    cat_slot_status, cat_slot_value = self._get_categorical_slot_goals(features)
    outputs["logit_cat_slot_status"] = cat_slot_status
    outputs["logit_cat_slot_value"] = cat_slot_value
    noncat_slot_status, noncat_span_start, noncat_span_end = (
        self._get_noncategorical_slot_goals(features))
    outputs["logit_noncat_slot_status"] = noncat_slot_status
    outputs["logit_noncat_slot_start"] = noncat_span_start
    outputs["logit_noncat_slot_end"] = noncat_span_end
    return outputs

  def define_loss(self, features, outputs):
    """Obtain the loss of the model."""
    # Intents.
    # Shape: (batch_size, max_num_intents + 1).
    intent_logits = outputs["logit_intent_status"]
    # Shape: (batch_size, max_num_intents).
    intent_labels = features["intent_status"]
    # Add label corresponding to NONE intent.
    num_active_intents = tf.expand_dims(
        tf.reduce_sum(intent_labels, axis=1), axis=1)
    none_intent_label = tf.ones_like(num_active_intents) - num_active_intents
    # Shape: (batch_size, max_num_intents + 1).
    onehot_intent_labels = tf.concat([none_intent_label, intent_labels], axis=1)
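    # For example, if a row of intent_labels is [0, 1, 0], then
    # num_active_intents is 1, none_intent_label is 0, and the corresponding
    # row of onehot_intent_labels is [0, 0, 1, 0]; when no intent is active,
    # the leading NONE position gets the 1 instead.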
    intent_loss = tf.losses.softmax_cross_entropy(
        onehot_intent_labels,
        intent_logits,
        weights=features["is_real_example"])

    # Requested slots.
    # Shape: (batch_size, max_num_slots).
    requested_slot_logits = outputs["logit_req_slot_status"]
    requested_slot_labels = features["req_slot_status"]
    max_num_requested_slots = requested_slot_labels.get_shape().as_list()[-1]
    weights = tf.sequence_mask(
        features["req_slot_num"], maxlen=max_num_requested_slots)
    # Sigmoid cross entropy is used because more than one slot can be requested
    # in a single utterance.
    requested_slot_loss = tf.losses.sigmoid_cross_entropy(
        requested_slot_labels, requested_slot_logits, weights=weights)

    # Categorical slot status.
    # Shape: (batch_size, max_num_cat_slots, 3).
    cat_slot_status_logits = outputs["logit_cat_slot_status"]
    cat_slot_status_labels = features["cat_slot_status"]
    max_num_cat_slots = cat_slot_status_labels.get_shape().as_list()[-1]
    one_hot_labels = tf.one_hot(cat_slot_status_labels, 3, dtype=tf.int32)
    cat_weights = tf.sequence_mask(
        features["cat_slot_num"], maxlen=max_num_cat_slots, dtype=tf.float32)
    cat_slot_status_loss = tf.losses.softmax_cross_entropy(
        tf.reshape(one_hot_labels, [-1, 3]),
        tf.reshape(cat_slot_status_logits, [-1, 3]),
        weights=tf.reshape(cat_weights, [-1]))

    # Categorical slot values.
    # Shape: (batch_size, max_num_cat_slots, max_num_slot_values).
    cat_slot_value_logits = outputs["logit_cat_slot_value"]
    cat_slot_value_labels = features["cat_slot_value"]
    max_num_slot_values = cat_slot_value_logits.get_shape().as_list()[-1]
    one_hot_labels = tf.one_hot(
        cat_slot_value_labels, max_num_slot_values, dtype=tf.int32)
    # Zero out losses for categorical slot value when the slot status is not
    # active.
    cat_loss_weight = tf.cast(
        tf.equal(cat_slot_status_labels, data_utils.STATUS_ACTIVE), tf.float32)
    cat_slot_value_loss = tf.losses.softmax_cross_entropy(
        tf.reshape(one_hot_labels, [-1, max_num_slot_values]),
        tf.reshape(cat_slot_value_logits, [-1, max_num_slot_values]),
        weights=tf.reshape(cat_weights * cat_loss_weight, [-1]))

    # Non-categorical slot status.
    # Shape: (batch_size, max_num_noncat_slots, 3).
    noncat_slot_status_logits = outputs["logit_noncat_slot_status"]
    noncat_slot_status_labels = features["noncat_slot_status"]
    max_num_noncat_slots = noncat_slot_status_labels.get_shape().as_list()[-1]
    one_hot_labels = tf.one_hot(noncat_slot_status_labels, 3, dtype=tf.int32)
    noncat_weights = tf.sequence_mask(
        features["noncat_slot_num"],
        maxlen=max_num_noncat_slots,
        dtype=tf.float32)
    # Logits for padded (invalid) values are already masked.
    noncat_slot_status_loss = tf.losses.softmax_cross_entropy(
        tf.reshape(one_hot_labels, [-1, 3]),
        tf.reshape(noncat_slot_status_logits, [-1, 3]),
        weights=tf.reshape(noncat_weights, [-1]))

    # Non-categorical slot spans.
    # Shape: (batch_size, max_num_noncat_slots, max_num_tokens).
    span_start_logits = outputs["logit_noncat_slot_start"]
    span_start_labels = features["noncat_slot_value_start"]
    max_num_tokens = span_start_logits.get_shape().as_list()[-1]
    onehot_start_labels = tf.one_hot(
        span_start_labels, max_num_tokens, dtype=tf.int32)
    # Shape: (batch_size, max_num_noncat_slots, max_num_tokens).
    span_end_logits = outputs["logit_noncat_slot_end"]
    span_end_labels = features["noncat_slot_value_end"]
    onehot_end_labels = tf.one_hot(
        span_end_labels, max_num_tokens, dtype=tf.int32)
    # Zero out losses for non-categorical slot spans when the slot status is not
    # active.
    noncat_loss_weight = tf.cast(
        tf.equal(noncat_slot_status_labels, data_utils.STATUS_ACTIVE),
        tf.float32)
    span_start_loss = tf.losses.softmax_cross_entropy(
        tf.reshape(onehot_start_labels, [-1, max_num_tokens]),
        tf.reshape(span_start_logits, [-1, max_num_tokens]),
        weights=tf.reshape(noncat_weights * noncat_loss_weight, [-1]))
    span_end_loss = tf.losses.softmax_cross_entropy(
        tf.reshape(onehot_end_labels, [-1, max_num_tokens]),
        tf.reshape(span_end_logits, [-1, max_num_tokens]),
        weights=tf.reshape(noncat_weights * noncat_loss_weight, [-1]))

    losses = {
        "intent_loss": intent_loss,
        "requested_slot_loss": requested_slot_loss,
        "cat_slot_status_loss": cat_slot_status_loss,
        "cat_slot_value_loss": cat_slot_value_loss,
        "noncat_slot_status_loss": noncat_slot_status_loss,
        "span_start_loss": span_start_loss,
        "span_end_loss": span_end_loss,
    }
    for loss_name, loss in losses.items():
      tf.summary.scalar(loss_name, loss)
    return sum(losses.values()) / len(losses)

  def define_predictions(self, features, outputs):
    """Define model predictions."""
    predictions = {
        "example_id": features["example_id"],
        "service_id": features["service_id"],
        "is_real_example": features["is_real_example"],
    }
    # Scores are output for each intent.
    # Note that the intent indices are shifted by 1 to account for NONE intent.
    predictions["intent_status"] = tf.argmax(
        outputs["logit_intent_status"], axis=-1)

    # Scores are output for each requested slot.
    predictions["req_slot_status"] = tf.sigmoid(
        outputs["logit_req_slot_status"])

    # For categorical slots, the status of each slot and the predicted value are
    # output.
    predictions["cat_slot_status"] = tf.argmax(
        outputs["logit_cat_slot_status"], axis=-1)
    predictions["cat_slot_value"] = tf.argmax(
        outputs["logit_cat_slot_value"], axis=-1)

    # For non-categorical slots, the status of each slot and the indices for
    # spans are output.
    predictions["noncat_slot_status"] = tf.argmax(
        outputs["logit_noncat_slot_status"], axis=-1)
    start_scores = tf.nn.softmax(outputs["logit_noncat_slot_start"], axis=-1)
    end_scores = tf.nn.softmax(outputs["logit_noncat_slot_end"], axis=-1)
    _, max_num_slots, max_num_tokens = end_scores.get_shape().as_list()
    batch_size = tf.shape(end_scores)[0]
    # Find the span with the maximum sum of scores for start and end indices.
    total_scores = (
        tf.expand_dims(start_scores, axis=3) +
        tf.expand_dims(end_scores, axis=2))
    # Mask out scores where start_index > end_index.
    start_idx = tf.reshape(tf.range(max_num_tokens), [1, 1, -1, 1])
    end_idx = tf.reshape(tf.range(max_num_tokens), [1, 1, 1, -1])
    invalid_index_mask = tf.tile((start_idx > end_idx),
                                 [batch_size, max_num_slots, 1, 1])
    total_scores = tf.where(invalid_index_mask, tf.zeros_like(total_scores),
                            total_scores)
    max_span_index = tf.argmax(
        tf.reshape(total_scores, [-1, max_num_slots, max_num_tokens**2]),
        axis=-1)
    span_start_index = tf.floordiv(max_span_index, max_num_tokens)
    span_end_index = tf.floormod(max_span_index, max_num_tokens)
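    # The argmax above runs over the flattened (start, end) grid, which is
    # laid out row-major, so e.g. with max_num_tokens = 80 a flat index of 163
    # decodes to start = 163 // 80 = 2 and end = 163 % 80 = 3.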
    predictions["noncat_slot_start"] = span_start_index
    predictions["noncat_slot_end"] = span_end_index
    # Add inverse alignments.
    predictions["noncat_alignment_start"] = features["noncat_alignment_start"]
    predictions["noncat_alignment_end"] = features["noncat_alignment_end"]

    return predictions

  def _encode_utterances(self, features, is_training):
    """Encode system and user utterances using BERT."""
    # Obtain the embedded representation of system and user utterances in the
    # turn and the corresponding token-level representations.
    bert_encoder = modeling.BertModel(
        config=self._bert_config,
        is_training=is_training,
        input_ids=features["utt"],
        input_mask=features["utt_mask"],
        token_type_ids=features["utt_seg"],
        use_one_hot_embeddings=self._use_one_hot_embeddings)
    encoded_utterance = bert_encoder.get_pooled_output()
    encoded_tokens = bert_encoder.get_sequence_output()

    # Apply dropout in training mode.
    encoded_utterance = tf.layers.dropout(
        encoded_utterance, rate=FLAGS.dropout_rate, training=is_training)
    encoded_tokens = tf.layers.dropout(
        encoded_tokens, rate=FLAGS.dropout_rate, training=is_training)
    return encoded_utterance, encoded_tokens

  def _get_logits(self, element_embeddings, num_classes, name_scope):
    """Get logits for elements by conditioning on utterance embedding.

    Args:
      element_embeddings: A tensor of shape (batch_size, num_elements,
        embedding_dim).
      num_classes: An int containing the number of classes for which logits are
        to be generated.
      name_scope: The name scope to be used for layers.

    Returns:
      A tensor of shape (batch_size, num_elements, num_classes) containing the
      logits.
    """
    _, num_elements, embedding_dim = element_embeddings.get_shape().as_list()
    # Project the utterance embeddings.
    utterance_proj = tf.keras.layers.Dense(
        units=embedding_dim,
        activation=modeling.gelu,
        name="{}_utterance_proj".format(name_scope))
    utterance_embedding = utterance_proj(self._encoded_utterance)
    # Combine the utterance and element embeddings.
    repeat_utterance_embeddings = tf.tile(
        tf.expand_dims(utterance_embedding, axis=1), [1, num_elements, 1])
    utterance_element_emb = tf.concat(
        [repeat_utterance_embeddings, element_embeddings], axis=2)
    # Project the combined embeddings to obtain logits.
    layer_1 = tf.keras.layers.Dense(
        units=embedding_dim,
        activation=modeling.gelu,
        name="{}_projection_1".format(name_scope))
    layer_2 = tf.keras.layers.Dense(
        units=num_classes, name="{}_projection_2".format(name_scope))
    return layer_2(layer_1(utterance_element_emb))

  def _get_intents(self, features):
    """Obtain logits for intents."""
    intent_embeddings = features["intent_emb"]
    # Add a trainable vector for the NONE intent.
    _, max_num_intents, embedding_dim = intent_embeddings.get_shape().as_list()
    null_intent_embedding = tf.get_variable(
        "null_intent_embedding",
        shape=[1, 1, embedding_dim],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    batch_size = tf.shape(intent_embeddings)[0]
    repeated_null_intent_embedding = tf.tile(null_intent_embedding,
                                             [batch_size, 1, 1])
    intent_embeddings = tf.concat(
        [repeated_null_intent_embedding, intent_embeddings], axis=1)

    logits = self._get_logits(intent_embeddings, 1, "intents")
    # Shape: (batch_size, max_num_intents + 1).
    logits = tf.squeeze(logits, axis=-1)
    # Mask out logits for padded intents. 1 is added to account for NONE intent.
    mask = tf.sequence_mask(
        features["intent_num"] + 1, maxlen=max_num_intents + 1)
    negative_logits = -0.7 * tf.ones_like(logits) * logits.dtype.max
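    # -0.7 * dtype.max is used instead of the most negative finite float,
    # presumably so the masked logits stay finite under any later arithmetic
    # while still contributing ~0 probability after the softmax. The same
    # masking pattern recurs for categorical values and span logits below.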
    return tf.where(mask, logits, negative_logits)

  def _get_requested_slots(self, features):
    """Obtain logits for requested slots."""
    slot_embeddings = features["req_slot_emb"]
    logits = self._get_logits(slot_embeddings, 1, "requested_slots")
    return tf.squeeze(logits, axis=-1)

  def _get_categorical_slot_goals(self, features):
    """Obtain logits for status and values for categorical slots."""
    # Predict the status of all categorical slots.
    slot_embeddings = features["cat_slot_emb"]
    status_logits = self._get_logits(slot_embeddings, 3,
                                     "categorical_slot_status")

    # Predict the goal value.

    # Shape: (batch_size, max_categorical_slots, max_categorical_values,
    # embedding_dim).
    value_embeddings = features["cat_slot_value_emb"]
    _, max_num_slots, max_num_values, embedding_dim = (
        value_embeddings.get_shape().as_list())
    value_embeddings_reshaped = tf.reshape(
        value_embeddings, [-1, max_num_slots * max_num_values, embedding_dim])
    value_logits = self._get_logits(value_embeddings_reshaped, 1,
                                    "categorical_slot_values")
    # Reshape to obtain the logits for all slots.
    value_logits = tf.reshape(value_logits, [-1, max_num_slots, max_num_values])
    # Mask out logits for padded slots and values because they will be
    # softmaxed.
    mask = tf.sequence_mask(
        features["cat_slot_value_num"], maxlen=max_num_values)
    negative_logits = -0.7 * tf.ones_like(value_logits) * value_logits.dtype.max
    value_logits = tf.where(mask, value_logits, negative_logits)
    return status_logits, value_logits

  def _get_noncategorical_slot_goals(self, features):
    """Obtain logits for status and slot spans for non-categorical slots."""
    # Predict the status of all non-categorical slots.
    slot_embeddings = features["noncat_slot_emb"]
    max_num_slots = slot_embeddings.get_shape().as_list()[1]
    status_logits = self._get_logits(slot_embeddings, 3,
                                     "noncategorical_slot_status")

    # Predict the distribution for span indices.
    token_embeddings = self._encoded_tokens
    max_num_tokens = token_embeddings.get_shape().as_list()[1]
    tiled_token_embeddings = tf.tile(
        tf.expand_dims(token_embeddings, 1), [1, max_num_slots, 1, 1])
    tiled_slot_embeddings = tf.tile(
        tf.expand_dims(slot_embeddings, 2), [1, 1, max_num_tokens, 1])
    # Shape: (batch_size, max_num_slots, max_num_tokens, 2 * embedding_dim).
    slot_token_embeddings = tf.concat(
        [tiled_slot_embeddings, tiled_token_embeddings], axis=3)

    # Project the combined embeddings to obtain logits.
    embedding_dim = slot_embeddings.get_shape().as_list()[-1]
    layer_1 = tf.keras.layers.Dense(
        units=embedding_dim,
        activation=modeling.gelu,
        name="noncat_spans_layer_1")
    layer_2 = tf.keras.layers.Dense(units=2, name="noncat_spans_layer_2")
    # Shape: (batch_size, max_num_slots, max_num_tokens, 2)
    span_logits = layer_2(layer_1(slot_token_embeddings))

    # Mask out invalid logits for padded tokens.
    token_mask = features["utt_mask"]  # Shape: (batch_size, max_num_tokens).
    token_mask = tf.cast(token_mask, tf.bool)
    tiled_token_mask = tf.tile(
        tf.expand_dims(tf.expand_dims(token_mask, 1), 3),
        [1, max_num_slots, 1, 2])
    negative_logits = -0.7 * tf.ones_like(span_logits) * span_logits.dtype.max
    span_logits = tf.where(tiled_token_mask, span_logits, negative_logits)
    # Shape of both tensors: (batch_size, max_num_slots, max_num_tokens).
    span_start_logits, span_end_logits = tf.unstack(span_logits, axis=3)
    return status_logits, span_start_logits, span_end_logits


# Modified from run_classifier.model_fn_builder
def _model_fn_builder(bert_config, init_checkpoint, learning_rate,
                      num_train_steps, num_warmup_steps, use_tpu,
                      use_one_hot_embeddings):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    is_training = (mode == tf_estimator.ModeKeys.TRAIN)

    schema_guided_dst = SchemaGuidedDST(bert_config, use_one_hot_embeddings)
    outputs = schema_guided_dst.define_model(features, is_training)
    if is_training:
      total_loss = schema_guided_dst.define_loss(features, outputs)
    else:
      total_loss = tf.constant(0.0)

    tvars = tf.trainable_variables()
    scaffold_fn = None
    if init_checkpoint:
      assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
          tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    output_spec = None
    if mode == tf_estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(total_loss, learning_rate,
                                               num_train_steps,
                                               num_warmup_steps, use_tpu)
      global_step = tf.train.get_or_create_global_step()
      logged_tensors = {
          "global_step": global_step,
          "total_loss": total_loss,
      }
      output_spec = tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn,
          training_hooks=[
              tf.train.LoggingTensorHook(logged_tensors, every_n_iter=5)
          ])

    elif mode == tf_estimator.ModeKeys.EVAL:
      output_spec = tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode, loss=total_loss, scaffold_fn=scaffold_fn)

    else:  # mode == tf_estimator.ModeKeys.PREDICT
      predictions = schema_guided_dst.define_predictions(features, outputs)
      output_spec = tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)

    return output_spec

  return model_fn


def _create_dialog_examples(processor, dial_file):
  """Creates dialog examples and saves them to the given file."""
  if not tf.io.gfile.exists(FLAGS.dialogues_example_dir):
    tf.io.gfile.makedirs(FLAGS.dialogues_example_dir)
  frame_examples = processor.get_dialog_examples(FLAGS.dataset_split)
  data_utils.file_based_convert_examples_to_features(frame_examples,
                                                     processor.dataset_config,
                                                     dial_file)


def _create_schema_embeddings(bert_config, schema_embedding_file,
                              dataset_config):
  """Creates schema embeddings and saves them to the given file."""
  if not tf.io.gfile.exists(FLAGS.schema_embedding_dir):
    tf.io.gfile.makedirs(FLAGS.schema_embedding_dir)
  is_per_host = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V2
  schema_emb_run_config = tf_estimator.tpu.RunConfig(
      master=FLAGS.master,
      tpu_config=tf_estimator.tpu.TPUConfig(
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  schema_json_path = os.path.join(FLAGS.dstc8_data_dir, FLAGS.dataset_split,
                                  "schema.json")
  schemas = schema.Schema(schema_json_path)

  # Prepare BERT model for embedding natural language descriptions.
  bert_init_ckpt = os.path.join(FLAGS.bert_ckpt_dir, "bert_model.ckpt")
  schema_emb_model_fn = extract_schema_embedding.model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=bert_init_ckpt,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)
  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  schema_emb_estimator = tf_estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=schema_emb_model_fn,
      config=schema_emb_run_config,
      predict_batch_size=FLAGS.predict_batch_size)
  vocab_file = os.path.join(FLAGS.bert_ckpt_dir, "vocab.txt")
  tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=FLAGS.do_lower_case)
  emb_generator = extract_schema_embedding.SchemaEmbeddingGenerator(
      tokenizer, schema_emb_estimator, FLAGS.max_seq_length)
  emb_generator.save_embeddings(schemas, schema_embedding_file, dataset_config)


def main(_):
  vocab_file = os.path.join(FLAGS.bert_ckpt_dir, "vocab.txt")
  task_name = FLAGS.task_name.lower()
  if task_name not in config.DATASET_CONFIG:
    raise ValueError("Task not found: %s" % (task_name))
  dataset_config = config.DATASET_CONFIG[task_name]
  processor = data_utils.Dstc8DataProcessor(
      FLAGS.dstc8_data_dir,
      dataset_config=dataset_config,
      vocab_file=vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      max_seq_length=FLAGS.max_seq_length,
      log_data_warnings=FLAGS.log_data_warnings)

  # Generate the dialogue examples if needed or specified.
  dial_file_name = "{}_{}_examples.tf_record".format(task_name,
                                                     FLAGS.dataset_split)
  dial_file = os.path.join(FLAGS.dialogues_example_dir, dial_file_name)
  if not tf.io.gfile.exists(dial_file) or FLAGS.overwrite_dial_file:
    tf.compat.v1.logging.info("Start generating the dialogue examples.")
    _create_dialog_examples(processor, dial_file)
    tf.compat.v1.logging.info("Finished generating the dialogue examples.")

  # Generate the schema embeddings if needed or specified.
  bert_init_ckpt = os.path.join(FLAGS.bert_ckpt_dir, "bert_model.ckpt")
  tokenization.validate_case_matches_checkpoint(
      do_lower_case=FLAGS.do_lower_case, init_checkpoint=bert_init_ckpt)

  bert_config = modeling.BertConfig.from_json_file(
      os.path.join(FLAGS.bert_ckpt_dir, "bert_config.json"))
  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  schema_embedding_file = os.path.join(
      FLAGS.schema_embedding_dir,
      "{}_pretrained_schema_embedding.npy".format(FLAGS.dataset_split))
  if (not tf.io.gfile.exists(schema_embedding_file) or
      FLAGS.overwrite_schema_emb_file):
    tf.compat.v1.logging.info("Start generating the schema embeddings.")
    _create_schema_embeddings(bert_config, schema_embedding_file,
                              dataset_config)
    tf.compat.v1.logging.info("Finished generating the schema embeddings.")

  # Create estimator for training or inference.
  if not tf.io.gfile.exists(FLAGS.output_dir):
    tf.io.gfile.makedirs(FLAGS.output_dir)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf_estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=None,
      tpu_config=tf_estimator.tpu.TPUConfig(
          # Recommended value is number of global steps for next checkpoint.
          iterations_per_loop=FLAGS.save_checkpoints_steps,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.run_mode == "train":
    num_train_examples = processor.get_num_dialog_examples(FLAGS.dataset_split)
    num_train_steps = int(num_train_examples / FLAGS.train_batch_size *
                          FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
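    # For example, 10,000 training examples with train_batch_size=32 and
    # num_train_epochs=80 give num_train_steps=25,000, and
    # warmup_proportion=0.1 then gives num_warmup_steps=2,500.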

  bert_init_ckpt = os.path.join(FLAGS.bert_ckpt_dir, "bert_model.ckpt")
  model_fn = _model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=bert_init_ckpt,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf_estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.run_mode == "train":
    # Train the model.
    tf.compat.v1.logging.info("***** Running training *****")
    tf.compat.v1.logging.info("  Num dial examples = %d", num_train_examples)
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
    train_input_fn = _file_based_input_fn_builder(
        dataset_config=dataset_config,
        input_dial_file=dial_file,
        schema_embedding_file=schema_embedding_file,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
  elif FLAGS.run_mode == "predict":
    # Run inference to obtain model predictions.
    num_actual_predict_examples = processor.get_num_dialog_examples(
        FLAGS.dataset_split)

    tf.compat.v1.logging.info("***** Running prediction *****")
    tf.compat.v1.logging.info("  Num actual examples = %d",
                              num_actual_predict_examples)
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = _file_based_input_fn_builder(
        dataset_config=dataset_config,
        input_dial_file=dial_file,
        schema_embedding_file=schema_embedding_file,
        is_training=False,
        drop_remainder=FLAGS.use_tpu)

    input_json_files = [
        os.path.join(FLAGS.dstc8_data_dir, FLAGS.dataset_split,
                     "dialogues_{:03d}.json".format(fid))
        for fid in dataset_config.file_ranges[FLAGS.dataset_split]
    ]
    schema_json_file = os.path.join(FLAGS.dstc8_data_dir, FLAGS.dataset_split,
                                    "schema.json")

    ckpt_nums = [num for num in FLAGS.eval_ckpt.split(",") if num]
    if not ckpt_nums:
      raise ValueError("No checkpoints assigned for prediction.")
    for ckpt_num in ckpt_nums:
      tf.compat.v1.logging.info("***** Predict results for %s set *****",
                                FLAGS.dataset_split)

      predictions = estimator.predict(
          input_fn=predict_input_fn,
          checkpoint_path=os.path.join(FLAGS.output_dir,
                                       "model.ckpt-%s" % ckpt_num))

      # Write predictions to file in DSTC8 format.
      dataset_mark = os.path.basename(FLAGS.dstc8_data_dir)
      prediction_dir = os.path.join(
          FLAGS.output_dir, "pred_res_{}_{}_{}_{}".format(
              int(ckpt_num), FLAGS.dataset_split, task_name, dataset_mark))
      if not tf.io.gfile.exists(prediction_dir):
        tf.io.gfile.makedirs(prediction_dir)
      pred_utils.write_predictions_to_file(predictions, input_json_files,
                                           schema_json_file, prediction_dir)


if __name__ == "__main__":
  flags.mark_flag_as_required("dstc8_data_dir")
  flags.mark_flag_as_required("bert_ckpt_dir")
  flags.mark_flag_as_required("dataset_split")
  flags.mark_flag_as_required("schema_embedding_dir")
  flags.mark_flag_as_required("dialogues_example_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("output_dir")
  tf.compat.v1.app.run(main)