# google-research / observation_sequence_model.py
# (Scraped page header: "Fork 0", 1004 lines, 45.5 KB.)
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Observation based RNN model."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import tensorflow.compat.v1 as tf
23
from tensorflow.compat.v1 import estimator as tf_estimator
24

25
from explaining_risk_increase import input_fn
26
from tensorflow.contrib import estimator as contrib_estimator
27
from tensorflow.contrib import lookup as contrib_lookup
28
from tensorflow.contrib import rnn as contrib_rnn
29
from tensorflow.contrib import training as contrib_training
30
from tensorflow.contrib.learn.python.learn.estimators import rnn_common
31

32

33
# Absolute tolerance used by construct_input to recognize observation values
# equal to the 9999999.0 missing-value sentinel, which are filtered out.
TOLERANCE = 0.2
34

35

36
class PredictionKeys(object):
  """Enum-like holder of the string keys used for the model's predictions."""
  # Key names for the logits, the probabilities and the predicted classes.
  LOGITS, PROBABILITIES, CLASSES = 'logits', 'probs', 'classes'
41

42

43
def _most_recent_obs_value(obs_values, indicator, delta_time,
                           attribution_max_delta_time):
  """Returns the most recent lab result for each test within a time frame.

  The eligible lab values are those that are at least
  attribution_max_delta_time old at the time of prediction. Among those we
  select, per batch entry and lab test, the most recent value, or zero if there
  is none.

  Args:
    obs_values: A dense representation of the observation_values at the position
      of their obs_code_ids. A padded Tensor of shape [batch_size,
      max_sequence_length, vocab_size] of type float32 where obs_values[b, t,
      id] = observation_values[b, t, 0] and id = observation_code_ids[b, t, 0]
      and obs_values[b, t, x] = 0 for all other x != id. If t is greater than
      the sequence_length of batch entry b then the result is 0 as well.
    indicator: A one-hot encoding of whether a value in obs_values comes from
      observation_values or is just filled in to be 0. A Tensor of shape
      [batch_size, max_sequence_length, vocab_size] and type float32. Its last
      dimension must be statically known.
    delta_time: A Tensor of shape [batch_size, max_sequence_length, 1]
      describing the time to prediction. It is concatenated with a
      [batch_size, 1, 1] tensor along axis 1, so it must be rank 3. It is cast
      to int32.
    attribution_max_delta_time: Time threshold so that we return the most recent
      lab values among those that are at least attribution_max_delta_time
      seconds old at time of prediction.

  Returns:
    A Tensor of shape [batch_size, 1, vocab_size] of the most recent lab results
    for all lab tests that are at least attribution_max_delta_time old at time
    of prediction.
  """
  batch_size = tf.shape(indicator)[0]
  seq_len = tf.shape(indicator)[1]
  num_obs = indicator.shape[2]

  # Prepend a dummy time-step (index 0) holding value 0, indicator 1 and a
  # delta_time exactly at the threshold. It is always eligible, so lab tests
  # for which we have no eligible lab values select 0.
  obs_values = tf.concat(
      [tf.zeros([batch_size, 1, num_obs]), obs_values], axis=1)
  indicator = tf.concat([tf.ones([batch_size, 1, num_obs]), indicator], axis=1)
  delta_time = tf.to_int32(delta_time)
  delta_time = tf.concat(
      [
          tf.zeros([batch_size, 1, 1], dtype=tf.int32) +
          attribution_max_delta_time, delta_time
      ],
      axis=1)

  # Keep only the lab values that are at least attribution_max_delta_time old.
  # [batch_size, max_sequence_length + 1, vocab_size] of int32.
  indicator = tf.to_int32(indicator)
  indicator *= tf.to_int32(delta_time >= attribution_max_delta_time)

  # Replace every eligible entry by its time-step index; the [1, T + 1, 1]
  # range broadcasts over the batch and vocab dimensions.
  time_steps = tf.reshape(tf.range(seq_len + 1), [1, -1, 1])
  seq_indicator = indicator * time_steps
  # [batch_size, 1, vocab_size] with the time-step of the last eligible value.
  last_val_indicator = tf.reduce_max(seq_indicator, axis=1, keepdims=True)

  # eq has exactly one True per (batch entry, lab test): the most recent
  # eligible value, or the dummy at time-step 0 if there is none.
  eq = tf.logical_and(
      tf.equal(seq_indicator, last_val_indicator), indicator > 0)
  most_recent_obs_value_indicator = tf.where(eq)
  # Collect the lab values associated with those indices.
  res = tf.gather_nd(obs_values, most_recent_obs_value_indicator)
  # Reorder the values by batch and then by lab test.
  res_sorted = tf.sparse_reorder(
      tf.sparse_transpose(
          tf.SparseTensor(
              indices=most_recent_obs_value_indicator,
              values=res,
              dense_shape=tf.to_int64(
                  tf.stack([batch_size, seq_len + 1, num_obs]))),
          perm=[0, 2, 1])).values

  return tf.reshape(res_sorted, [batch_size, 1, num_obs])
122

123

124
def _predictions_for_gradients(predictions, seq_mask, delta_time,
                               attribution_max_delta_time, averaged):
  """Aggregates eligible predictions over time.

  Predictions are eligible if they are within the sequence_length (as indicated
  by seq_mask) and their associated delta_time is less than
  attribution_max_delta_time. Eligible predictions are either averaged across
  the eligible time-steps (if averaged=True) or summed otherwise; ineligible
  predictions are ignored.

  Args:
    predictions: A Tensor of shape [batch_size, max_seq_len, 1]
      with the predictions in the sequence.
    seq_mask: A Tensor of shape [batch_size, max_sequence_length, 1] indicating
      which timesteps are padded.
    delta_time: A Tensor describing the time to prediction. It must broadcast
      against seq_mask, e.g. [batch_size, max_sequence_length, 1].
    attribution_max_delta_time: Attribution is limited to values that are
      younger than that many seconds at time of prediction.
    averaged: Whether predictions are averaged over the eligible time-steps
      (True) or simply summed up across them (False).

  Returns:
    A Tensor of shape [batch, 1, 1] of the eligible predictions aggregated
    across time. When averaging, a batch entry without any eligible time-step
    yields 0 rather than the NaN of a 0 / 0 division.
  """
  mask = seq_mask * tf.to_float(delta_time < attribution_max_delta_time)
  predictions *= mask
  if averaged:
    # The numerator is already 0 wherever the count is 0, so div_no_nan only
    # changes the result for batch entries with no eligible time-step.
    predictions = tf.div_no_nan(
        predictions, tf.reduce_sum(mask, axis=1, keepdims=True))
  return tf.reduce_sum(predictions, axis=1, keepdims=True)
154

155

156
def compute_gradient_attribution(predictions, obs_values, indicator):
  """Attributes the predictions to the observed lab values via gradients.

  Computes the gradient of the (summed) predictions with respect to
  obs_values. For every time-step it keeps only the entry of the lab test that
  was actually observed there (via indicator), giving one attribution value per
  time-step.

  Args:
    predictions: A Tensor of shape [batch_size, 1, 1] with the
      predictions in the sequence.
    obs_values: A dense representation of the observation_values with
      obs_values[b, t, :] has at most one non-zero value at the position
      of the corresponding lab test from obs_code_ids with the value of the lab
      result. A padded Tensor of shape [batch_size, max_sequence_length,
      vocab_size] of type float32 of possibly normalized observation values.
    indicator: A one-hot encoding of whether a value in obs_values comes from
      observation_values or is just filled in to be 0. A Tensor of
      shape [batch_size, max_sequence_length, vocab_size] and type float32.

  Returns:
    A Tensor of shape [batch, max_sequence_length, 1] of the gradient of the
    prediction as a function of the lab result at that batch-entry time. It is
    all zeros if predictions does not depend on obs_values.
  """
  attr = tf.gradients(tf.squeeze(predictions, axis=1,
                                 name='squeeze_pred_for_gradients'),
                      [obs_values])[0]
  if attr is None:
    # tf.gradients returns None for inputs that the predictions do not depend
    # on; without this, the masking below would fail on None.
    attr = tf.zeros_like(obs_values)
  # Zero-out gradients for other lab-tests and then sum up across lab tests
  # for which at most one gradient will be non-zero.
  attr *= indicator
  return tf.reduce_sum(attr, axis=2, keepdims=True)
186

187

188
def compute_path_integrated_gradient_attribution(
    obs_values,
    indicator,
    diff_delta_time,
    delta_time,
    sequence_length,
    seq_mask,
    hparams,
    construct_logits_fn=None):
  """Constructs the attribution of what inputs result in a higher prediction.

  Attribution here refers to the integrated gradients as defined in
  https://arxiv.org/pdf/1703.01365.pdf. For the j-th variable it is
  approximated with n = hparams.path_integrated_gradients_num_steps evaluation
  points that include both end points of the path:

    (x-x') * 1/n * sum_{i=0}^{n-1} of the derivative of
    F(x'+(x-x')*i/(n-1)) w.r.t. its j-th input.

  Here x are the observation values and x' is the baseline. In x', every lab
  value less than attribution_max_delta_time old is replaced by the most
  recent value of the same lab test that is at least that old (or 0 if there
  is none). Older values are identical in x and x', so only recent values get
  attribution. x'+(x-x')*i/(n-1) is the linear interpolation between x' and x.

  Args:
    obs_values: A dense representation of the observation_values with
      obs_values[b, t, :] has at most one non-zero value at the position
      of the corresponding lab test from obs_code_ids with the value of the lab
      result. A padded Tensor of shape [batch_size, max_sequence_length,
      vocab_size] of type float32 of possibly normalized observation values.
    indicator: A one-hot encoding of whether a value in obs_values comes from
      observation_values or is just filled in to be 0. A Tensor of
      shape [batch_size, max_sequence_length, vocab_size] and type float32.
    diff_delta_time: Difference between two consecutive time steps.
    delta_time: A Tensor of shape [batch_size, max_sequence_length, 1]
      describing the time to prediction.
    sequence_length: Sequence length (before padding), Tensor of shape
      [batch_size].
    seq_mask: A Tensor of shape [batch_size, max_sequence_length, 1]
      indicating which timesteps are padded.
    hparams: Hyper parameters. Reads attribution_max_delta_time,
      path_integrated_gradients_num_steps (must be >= 2) and use_rnn_attention.
    construct_logits_fn: A method with constructing the logits given input as
      construct_logits. It is called as fn(diff_delta_time, obs_values,
      indicator, sequence_length, seq_mask, hparams, reuse=True) and must
      return a (logits, _) pair. If None using construct_logits.

  Returns:
    A Tensor of shape [batch, max_sequence_length, 1] of the gradient of the
    prediction as a function of the lab result at that batch-entry time.

  Raises:
    ValueError: If hparams.path_integrated_gradients_num_steps is less than 2
      (the interpolation step is 1 / (num_steps - 1)).
  """
  num_steps = hparams.path_integrated_gradients_num_steps
  if num_steps < 2:
    raise ValueError(
        'path_integrated_gradients_num_steps must be at least 2, got %r' %
        (num_steps,))
  if not construct_logits_fn:
    construct_logits_fn = construct_logits
  max_delta_time = hparams.attribution_max_delta_time

  last_obs_values_0 = _most_recent_obs_value(obs_values, indicator, delta_time,
                                             max_delta_time)
  # We need to limit the diff over the base to timesteps after base.
  last_obs_values = last_obs_values_0 * (
      tf.to_float(indicator) * tf.to_float(delta_time < max_delta_time))
  # Baseline x': old values are kept, recent ones replaced by the last old one.
  obs_values_with_last_replaced = obs_values * tf.to_float(
      delta_time >= max_delta_time) + last_obs_values
  # x - x': non-zero only at recent, observed (time-step, lab test) pairs.
  diff_over_base = obs_values - obs_values_with_last_replaced
  # At most one entry per time-step is non-zero, so summing over the vocab axis
  # gives the per-time-step difference. It does not depend on the step.
  diff_per_timestep = tf.reduce_sum(diff_over_base, axis=2, keepdims=True)

  gradients = []
  for i in range(num_steps):
    alpha = 1.0 * i / (num_steps - 1)
    step_obs_values = obs_values_with_last_replaced + diff_over_base * alpha
    logits, _ = construct_logits_fn(
        diff_delta_time,
        step_obs_values,
        indicator,
        sequence_length,
        seq_mask,
        hparams,
        reuse=True)
    if hparams.use_rnn_attention:
      last_logits = logits
    else:
      last_logits = rnn_common.select_last_activations(
          logits, tf.to_int32(sequence_length))
    # Ideally, we'd like to get the gradients of the change in
    # value over the previous one to attribute it to both and not just a single
    # value.
    gradient = compute_gradient_attribution(last_logits, step_obs_values,
                                            indicator)
    gradients.append(diff_per_timestep * gradient)
  return tf.add_n(gradients) / tf.to_float(num_steps)
271

272

273
def compute_attention(seq_output, last_output, hidden_layer_dim, seq_mask,
                      sequence_length):
  """Constructs attention of the last_output as query and the sequence output.

  The attention score is the dot product of the last_output (the final RNN
  output, used as the "query" or "context" vector) with the seq_output (the
  RNN's output at each step). The attention output is a weighted sum of the
  seq_output over all non-padded steps. Details:

    alpha_i = seq_output_i . last_output   (alpha_i = -1e9 for padded steps)
    beta_i = exp(alpha_i) / sum_j exp(alpha_j)
    attention_vector = sum_i beta_i * seq_output_i

  If hidden_layer_dim > 0 then before computing alpha the seq_output and the
  last_output are sent through two separate relu6 dense layers:
    seq_output = hidden_layer(seq_output)
    last_output = hidden_layer(last_output)
  and the weighted sum is taken over the projected seq_output.

  Also writes a 'last_beta_attention' histogram summary of the attention
  weight at the last valid step of each sequence.

  Args:
    seq_output: The raw rnn output of shape [batch_size, max_sequence_length,
      rnn_size].
    last_output: The last output of the rnn of shape [batch_size, rnn_size].
    hidden_layer_dim: If 0 no hidden layer is applied before multiplying the
      last_logits with the seq_logits.
    seq_mask: A Tensor of shape [batch_size, max_sequence_length, 1] indicating
      which timesteps are padded (0 for padding, 1 otherwise).
    sequence_length: Sequence length (before padding), Tensor of shape
      [batch_size].

  Returns:
    Attention output with shape [batch_size, rnn_size], or
    [batch_size, hidden_layer_dim] when hidden_layer_dim > 0.
    The attention beta tensor of shape [batch_size, max_sequence_length, 1].
  """
  # Compute the weights.
  if hidden_layer_dim > 0:
    last_output = tf.layers.dense(
        last_output, hidden_layer_dim, activation=tf.nn.relu6)
    seq_output = tf.layers.dense(
        seq_output, hidden_layer_dim, activation=tf.nn.relu6)
  last_output = tf.expand_dims(last_output, 1)  # [batch_size, 1, dim]
  tmp = tf.multiply(seq_output, last_output)  # dim 1: broadcast
  alpha_tensor = tf.reduce_sum(tmp, 2)  # [b, max_seq_len]
  # Multiplying the scores of padded steps by 0 is not enough: exp(0) = 1 would
  # still give them softmax weight. Send them to a large negative score so they
  # get (numerically) zero weight. An all-padded row still yields a uniform
  # distribution, as before.
  mask = tf.squeeze(seq_mask, axis=2)
  alpha_tensor = alpha_tensor * mask + (1.0 - mask) * -1e9
  beta_tensor = tf.nn.softmax(alpha_tensor)  # using default dim -1
  beta_tensor = tf.expand_dims(beta_tensor, -1)  # [b, max_seq_len, 1]

  # Weighted sum of the (possibly projected) outputs over all steps.
  # NOTE(review): with hidden_layer_dim > 0 this sums the projected outputs,
  # not the original rnn outputs; confirm this is intended.
  tmp = seq_output * beta_tensor  # last dim: use "broadcast"
  rnn_outputs_weighted_sum = tf.reduce_sum(tmp, 1)  # [b, dim]
  last_beta = rnn_common.select_last_activations(
      beta_tensor, tf.to_int32(sequence_length))
  tf.summary.histogram('last_beta_attention', last_beta)

  return rnn_outputs_weighted_sum, beta_tensor
329

330

331
def compute_prediction_diff_attribution(logits):
  """Attributes predicted risk to the time-steps where it went up.

  The sigmoid of the logits gives the prediction at every time-step. The
  attribution of time-step t is prediction[t] - prediction[t - 1]. The
  prediction before the first time-step is taken to be 0, so the first
  time-step is attributed its full prediction.

  Args:
    logits: The logits of the model_fn, a rank-3 Tensor
      [batch_size, max_sequence_length, d].

  Returns:
    A Tensor of shape [batch_size, max_sequence_length, d] (d is 1 for the
    model's logits) with an attribution value at time t of prediction at time
    t minus prediction at time t-1.
  """
  probs = tf.sigmoid(logits)
  dims = tf.shape(logits)
  # Predictions delayed by one time-step, with a zero step in front.
  previous_probs = tf.concat(
      [
          tf.zeros(shape=[dims[0], 1, dims[2]], dtype=tf.float32),
          probs[:, :-1, :],
      ],
      axis=1,
      name='attribution')
  return probs - previous_probs
350

351

352
def convert_attribution(attribution, sequence_feature_map, seq_mask, delta_time,
                        attribution_threshold, attribution_max_delta_time,
                        prefix=''):
  """Turns a dense attribution into one SparseTensor per sequence feature.

  Only non-padded time-steps that are less than attribution_max_delta_time
  old at the time of prediction are kept. Of those, only the entries whose
  attribution exceeds attribution_threshold remain.

  Args:
    attribution: A Tensor of shape [batch, max_sequence_length, 1] computed
      using some attribution method.
    sequence_feature_map: A dictionary from name to (Sparse)Tensor.
    seq_mask: A Tensor of shape [batch_size, max_sequence_length, 1] indicating
      which timesteps are padded.
    delta_time: A Tensor describing the time to prediction, broadcastable
      against attribution.
    attribution_threshold: Attribution values not above this threshold are
      dropped.
    attribution_max_delta_time: Attribution is limited to values that are
      younger than that many seconds at time of prediction.
    prefix: A string to prepend to the feature names for the attribution_dict.

  Returns:
    A dictionary from prefix + feature name to a SparseTensor. Each one holds
    the kept attribution values at indices [b, t, 0]. Its dense_shape is the
    feature's shape with a unit dimension inserted at axis 1 (via
    tf.sparse.expand_dims(..., axis=1)). Limitation: for features whose last
    (token) dimension is > 1 only the first token is highlighted.
  """
  # Drop padding, then everything older than attribution_max_delta_time.
  masked = attribution * seq_mask
  masked *= tf.to_float(delta_time < attribution_max_delta_time)

  # Keep only the sufficiently large attributions as (index, value) pairs.
  kept_indices = tf.where(masked > attribution_threshold)
  kept_values = tf.gather_nd(masked, kept_indices)

  def _to_sparse_attribution(feature_tensor):
    """Wraps the kept attributions in a SparseTensor shaped like the feature."""
    return tf.sparse.expand_dims(
        tf.SparseTensor(
            indices=kept_indices,
            values=kept_values,
            dense_shape=tf.to_int64(tf.shape(feature_tensor))),
        axis=1)

  return {
      prefix + name: _to_sparse_attribution(feature)
      for name, feature in sequence_feature_map.items()
  }
402

403

404
def normalize_each_feature(observation_values, obs_code, vocab_size, mode,
                           momentum):
  """Batch-normalizes the observation values separately for each lab test.

  For every id i in [0, vocab_size), the values whose code equals i go through
  their own tf.layers.BatchNormalization layer, created under the variable
  scopes 'batch_normalization' / 'bn<i>'. NaN outputs are replaced by 0. In
  TRAIN mode, summaries of the raw values, the normalized values and the
  moving statistics are written.

  Args:
    observation_values: A SparseTensor of type float with the observation
      values of dense shape [batch_size, max_sequence_length, 1].
      There may be no time gaps in between codes.
    obs_code: A rank-1 integer Tensor with one id per entry of
      observation_values.values (element k is the id of value k). We will do
      the normalization separately for each lab test.
    vocab_size: The range of the values in obs_code is from 0 to
      vocab_size - 1.
    mode: The execution mode, as defined in tf.estimator.ModeKeys.
    momentum: Mean and variance will be updated as
      momentum*old_value + (1-momentum) * new_value.

  Returns:
    A SparseTensor with the dense_shape of observation_values holding the
    normalized values, in canonical row-major order. Entries whose id lies
    outside [0, vocab_size) are dropped.
  """
  is_training = mode == tf_estimator.ModeKeys.TRAIN
  with tf.variable_scope('batch_normalization'):
    new_indices = []
    new_values = []

    for i in range(vocab_size):
      with tf.variable_scope('bn' + str(i)):
        # [num_i, 1] positions into observation_values.values with id i.
        positions_of_feature_i = tf.where(tf.equal(obs_code, i))
        values_of_feature_i = tf.gather_nd(observation_values.values,
                                           positions_of_feature_i)
        if is_training:
          tf.summary.scalar('avg_observation_values/' + str(i),
                            tf.reduce_mean(values_of_feature_i))
          tf.summary.histogram('observation_values/' + str(i),
                               values_of_feature_i)
        # NOTE(review): when lab test i is absent from a training batch the
        # batch moments are taken over an empty tensor; confirm this cannot
        # write NaN into moving_mean / moving_variance.
        batchnorm_layer = tf.layers.BatchNormalization(
            axis=1,
            momentum=momentum,
            epsilon=0.01,
            trainable=True)
        normalized_values = tf.squeeze(
            batchnorm_layer.apply(
                tf.expand_dims(values_of_feature_i, axis=1),
                training=is_training),
            axis=1,
            name='squeeze_normalized_values')
        if is_training:
          tf.summary.scalar('batchnorm_layer/moving_mean/' + str(i),
                            tf.squeeze(batchnorm_layer.moving_mean))
          tf.summary.scalar('batchnorm_layer/moving_variance/' + str(i),
                            tf.squeeze(batchnorm_layer.moving_variance))
          tf.summary.scalar('avg_normalized_values/' + str(i),
                            tf.reduce_mean(normalized_values))
          tf.summary.histogram('normalized_observation_values/' + str(i),
                               normalized_values)
        # Full [num_i, 3] sparse indices of the values with id i.
        indices_i = tf.gather_nd(observation_values.indices,
                                 positions_of_feature_i)
        new_indices.append(indices_i)
        # Guard against NaN normalization outputs by mapping them to 0.
        normalized_values = tf.where(tf.is_nan(normalized_values),
                                     tf.zeros_like(normalized_values),
                                     normalized_values)
        new_values.append(normalized_values)

    normalized_sp_tensor = tf.SparseTensor(
        indices=tf.concat(new_indices, axis=0),
        values=tf.concat(new_values, axis=0),
        dense_shape=observation_values.dense_shape)
    # The entries were grouped by lab test above; restore row-major order.
    normalized_sp_tensor = tf.sparse_reorder(normalized_sp_tensor)
    return normalized_sp_tensor
471

472

473
def combine_observation_code_and_values(observation_code_ids,
                                        observation_values, vocab_size, mode,
                                        normalize, momentum, min_value,
                                        max_value):
  """Combines SparseTensors of observation codes and values into a Tensor.

  Args:
    observation_code_ids: A SparseTensor of type int32 with the ids of the
      observation codes of dense shape [batch_size, max_sequence_length, 1].
      There may be no time gaps in between codes.
    observation_values: A SparseTensor of type float with the observation
      values of dense shape [batch_size, max_sequence_length, 1].
      There may be no time gaps in between codes.
    vocab_size: The range of the values in obs_code_ids is from 0 to vocab_size.
    mode: The execution mode, as defined in tf.estimator.ModeKeys.
    normalize: Whether to normalize each lab test.
    momentum: For the batch normalization mean and variance will be updated as
      momentum*old_value + (1-momentum) * new_value.
    min_value: Observation values smaller than this will be capped to min_value.
    max_value: Observation values larger than this will be capped to max_value.
      Capping applies to observed values only; filled-in entries stay 0.

  Returns:
    - obs_values: A dense representation of the observation_values at the
                  position of their obs_code_ids. A padded Tensor of shape
                  [batch_size, max_sequence_length, vocab_size] of type float32
                  where obs_values[b, t, id] = observation_values[b, t, 0] and
                  id = observation_code_ids[b, t, 0] and obs_values[b, t, x] = 0
                  for all other x != id. If t is greater than the
                  sequence_length of batch entry b then the result is 0 as well.
    - indicator: A one-hot encoding of whether a value in obs_values comes from
                 observation_values or is just filled in to be 0. A Tensor of
                 shape [batch_size, max_sequence_length, vocab_size] and type
                 float32.
  """
  obs_code = observation_code_ids.values
  if normalize:
    with tf.variable_scope('values'):
      observation_values = normalize_each_feature(
          observation_values, obs_code, vocab_size, mode, momentum)
  # Clip the observed values before densifying. Clipping the dense tensor
  # would also move the filled-in zeros to min_value / max_value whenever 0
  # lies outside [min_value, max_value], breaking the "0 for x != id" contract.
  clipped_values = tf.maximum(
      tf.minimum(observation_values.values, max_value), min_value)
  # (batch, time, code id) index of every observation in the dense tensor.
  obs_indices = tf.concat(
      [observation_values.indices[:, 0:2],
       tf.expand_dims(obs_code, axis=1)],
      axis=1, name='obs_indices')
  obs_shape = tf.concat(
      [observation_values.dense_shape[0:2], [vocab_size]], axis=0,
      name='obs_shape')

  obs_values = tf.sparse_to_dense(obs_indices, obs_shape, clipped_values)
  obs_values.set_shape([None, None, vocab_size])
  indicator = tf.sparse_to_dense(obs_indices, obs_shape,
                                 tf.ones_like(observation_values.values))
  indicator.set_shape([None, None, vocab_size])
  return obs_values, indicator
534

535

536
def construct_input(sequence_feature_map, categorical_values,
                    categorical_seq_feature, feature_value, mode, normalize,
                    momentum, min_value, max_value, input_keep_prob):
  """Builds the dense RNN input tensors from the raw sequence features.

  Observation values equal to the 9999999.0 missing-value sentinel (within
  TOLERANCE) are dropped. During training, input dropout is also applied.
  Observations whose code is not in categorical_values are dropped too. The
  remaining codes and values are then combined into dense one-hot style
  tensors.

  Args:
    sequence_feature_map: A dictionary of (Sparse)Tensors of dense shape
      [batch_size, max_sequence_length, None] keyed by the feature name.
      categorical_seq_feature and feature_value must be SparseTensors with
      aligned entries, and 'deltaTime' must be a dense rank-3 Tensor. The
      dictionary passed in is not modified.
    categorical_values: Potential values of the categorical_seq_feature.
    categorical_seq_feature: Name of feature of observation code.
    feature_value: Name of feature of observation value.
    mode: The execution mode, as defined in tf.estimator.ModeKeys.
    normalize: Whether to normalize each lab test.
    momentum: For the batch normalization mean and variance will be updated as
      momentum*old_value + (1-momentum) * new_value.
    min_value: Observation values smaller than this will be capped to min_value.
    max_value: Observation values larger than this will be capped to max_value.
    input_keep_prob: Keep probability for input observation values (only used
      in TRAIN mode).

  Returns:
    - diff_delta_time: float32 Tensor of shape [batch_size, max_seq_length, 1]
      with deltaTime[t - 1] - deltaTime[t] divided by 3600, i.e. the time
      elapsed since the previous time-step in hours if deltaTime is in
      seconds. It is 0 at the first time-step.
    - obs_values: A dense representation of the observation_values with
                  obs_values[b, t, :] has at most one non-zero value at the
                  position of the corresponding lab test from obs_code_ids with
                  the value of the lab result. A padded Tensor of shape
                  [batch_size, max_sequence_length, vocab_size] of type float32
                  of possibly normalized observation values.
    - indicator: A one-hot encoding of whether a value in obs_values comes from
                 observation_values or is just filled in to be 0. A Tensor of
                 shape [batch_size, max_sequence_length, vocab_size] and type
                 float32.
  """
  with tf.variable_scope('input'):
    # Work on a reordered copy so the caller's dictionary is left untouched.
    sequence_feature_map = {
        k: tf.sparse_reorder(s) if isinstance(s, tf.SparseTensor) else s
        for k, s in sequence_feature_map.items()
    }
    # Filter out invalid values.
    # For invalid observation values we do this through a sparse retain.
    # This makes sure that the invalid values will not be considered in the
    # normalization.
    observation_values = sequence_feature_map[feature_value]
    observation_code_sparse = sequence_feature_map[categorical_seq_feature]
    # Future work: Create a flag for the missing value indicator.
    # 9999999.0 marks a missing observation value.
    valid_values = tf.abs(observation_values.values - 9999999.0) > TOLERANCE
    # Apply input dropout. It only affects training, so the random op is only
    # created in that mode.
    if input_keep_prob < 1.0 and mode == tf_estimator.ModeKeys.TRAIN:
      random_tensor = input_keep_prob
      random_tensor += tf.random_uniform(tf.shape(observation_values.values))
      # 0. if [input_keep_prob, 1.0) and 1. if [1.0, 1.0 + input_keep_prob)
      dropout_mask = tf.floor(random_tensor)
      valid_values = tf.to_float(valid_values) * dropout_mask
      valid_values = valid_values > 0.5
    observation_values = tf.sparse_retain(observation_values, valid_values)
    obs_code = tf.sparse_retain(observation_code_sparse, valid_values)

    # 1. Construct the sequence of observation values to feed into the RNN
    #    and their indicator.
    # We assign each observation code an id from 0 to vocab_size-1. At each
    # timestep we will lookup the id for the observation code and take the value
    # of the lab test and a construct a vector with all zeros but the id-th
    # position is set to the lab test value.
    obs_code_dense_ids = contrib_lookup.index_table_from_tensor(
        tuple(categorical_values), num_oov_buckets=0,
        name='vocab_lookup').lookup(obs_code.values)
    # With num_oov_buckets=0, codes outside categorical_values are looked up as
    # -1. That is not a valid position along the vocab axis: sparse_to_dense
    # would reject it, and normalization would drop the value but keep the
    # code, misaligning the two. Drop such observations from both.
    in_vocab = tf.greater_equal(obs_code_dense_ids, 0)
    obs_code_sparse = tf.SparseTensor(
        values=obs_code_dense_ids,
        indices=obs_code.indices,
        dense_shape=obs_code.dense_shape)
    obs_code_sparse = tf.sparse_reorder(
        tf.sparse_retain(obs_code_sparse, in_vocab))
    observation_values = tf.sparse_reorder(
        tf.sparse_retain(observation_values, in_vocab))
    vocab_size = len(categorical_values)
    obs_values, indicator = combine_observation_code_and_values(
        obs_code_sparse, observation_values, vocab_size, mode, normalize,
        momentum, min_value, max_value)

    # 2. We compute the diff_delta_time as additional sequence feature.
    # Note, the LSTM is very sensitive to how you encode time.
    delta_time = sequence_feature_map['deltaTime']
    diff_delta_time = tf.concat(
        [delta_time[:, :1, :], delta_time[:, :-1, :]], axis=1) - delta_time
    diff_delta_time = tf.to_float(diff_delta_time) / (60.0 * 60.0)

  return (diff_delta_time, obs_values, indicator)
626

627

628
def construct_rnn_logits(diff_delta_time,
                         obs_values,
                         indicator,
                         sequence_length,
                         rnn_size,
                         variational_recurrent_keep_prob,
                         variational_input_keep_prob,
                         variational_output_keep_prob,
                         reuse=False):
  """Computes logits combining inputs and applying an RNN.

  The three per-timestep inputs are concatenated along axis 2, run through an
  LSTM (wrapped with variational dropout when any keep probability is below 1)
  and mapped to a single logit per timestep by a dense layer.

  Args:
    diff_delta_time: Difference between two consecutive time steps. A rank-3
      float Tensor [batch_size, max_sequence_length, ...] (presumably with a
      last dimension of 1 -- TODO confirm against construct_input).
    obs_values: A dense representation of the observation_values with
      obs_values[b, t, :] has at most one non-zero value at the position
      of the corresponding lab test from obs_code_ids with the value of the lab
      result. A padded Tensor of shape [batch_size, max_sequence_length,
      vocab_size] of type float32 of possibly normalized observation values.
    indicator: A one-hot encoding of whether a value in obs_values comes from
      observation_values or is just filled in to be 0. A Tensor of
      shape [batch_size, max_sequence_length, vocab_size] and type float32.
    sequence_length: Sequence length (before padding), Tensor of shape
      [batch_size].
    rnn_size: Size of the LSTM hidden state and output.
    variational_recurrent_keep_prob: 1 - dropout for the hidden LSTM state.
    variational_input_keep_prob: 1 - dropout for the input to the LSTM.
    variational_output_keep_prob: 1 - dropout for the output of the LSTM.
    reuse: Whether to reuse existing variables or setup new ones.

  Returns:
    A tuple (logits, output):
      logits: A Tensor of shape [batch_size, max_sequence_length, 1].
      output: The raw (batch-major) RNN output, a Tensor of shape
        [batch_size, max_sequence_length, rnn_size].
  """
  with tf.variable_scope('logits/rnn', reuse=reuse) as sc:
    rnn_inputs = [diff_delta_time, indicator, obs_values]
    sequence_data = tf.concat(rnn_inputs, axis=2, name='rnn_input')

    # Run a recurrent neural network across the time dimension.
    cell = contrib_rnn.LSTMCell(rnn_size, state_is_tuple=True)
    # Only pay for the dropout wrapper when some dropout is actually applied.
    if (variational_recurrent_keep_prob < 1 or variational_input_keep_prob < 1
        or variational_output_keep_prob < 1):
      cell = tf.nn.rnn_cell.DropoutWrapper(
          cell, input_keep_prob=variational_input_keep_prob,
          output_keep_prob=variational_output_keep_prob,
          state_keep_prob=variational_recurrent_keep_prob,
          # Variational dropout needs the input depth to build its masks.
          variational_recurrent=True, input_size=tf.shape(sequence_data)[2],
          # Fixed seed keeps the dropout masks reproducible across runs.
          seed=12345678)

    output, _ = tf.nn.dynamic_rnn(
        cell,
        sequence_data,
        sequence_length=sequence_length,
        dtype=tf.float32,
        swap_memory=True,
        scope='rnn')

    # 3. Make a time-series of logits via a linear-mapping of the rnn-output
    #    to logits_dimension = 1. The enclosing scope is reused as the layer
    #    name so the projection variables live under 'logits/rnn'.
    logits = tf.layers.dense(
        output, 1, name=sc, reuse=reuse, activation=None)
    return logits, output
687

688

689
def construct_logits(diff_delta_time, obs_values, indicator, sequence_length,
                     seq_mask, hparams, reuse):
  """Builds the model logits with an RNN and optional attention on top.

  The RNN is always constructed via construct_rnn_logits. When
  hparams.use_rnn_attention is False its per-timestep logits are returned
  directly. Otherwise the RNN outputs are pooled by compute_attention (keyed
  on the output at the last valid timestep) and projected to a single logit.

  Args:
    diff_delta_time: Difference between two consecutive time steps.
    obs_values: A dense representation of the observation_values with
      obs_values[b, t, :] has at most one non-zero value at the position
      of the corresponding lab test from obs_code_ids with the value of the lab
      result. A padded Tensor of shape [batch_size, max_sequence_length,
      vocab_size] of type float32 of possibly normalized observation values.
    indicator: A one-hot encoding of whether a value in obs_values comes from
      observation_values or is just filled in to be 0. A Tensor of
      shape [batch_size, max_sequence_length, vocab_size] and type float32.
    sequence_length: Sequence length (before padding), Tensor of shape
      [batch_size].
    seq_mask: A Tensor of shape [batch_size, max_sequence_length, 1] indicating
      which timesteps are padded.
    hparams: Hyper parameters.
    reuse: Boolean indicator of whether to re-use the variables.

  Returns:
    A tuple (logits, weights):
      logits: A Tensor of shape [batch, {max_sequence_length,1}, 1] --
        per-timestep logits without attention, a single projected logit
        with attention.
      weights: None, unless hparams.use_rnn_attention is True, in which case
        the attention weights from compute_attention, a Tensor of shape
        [batch, max_sequence_length, 1].
  """
  rnn_logits, rnn_output = construct_rnn_logits(
      diff_delta_time, obs_values, indicator, sequence_length,
      hparams.rnn_size, hparams.variational_recurrent_keep_prob,
      hparams.variational_input_keep_prob,
      hparams.variational_output_keep_prob, reuse=reuse)

  # Guard clause: the plain RNN model needs no attention weights.
  if not hparams.use_rnn_attention:
    return rnn_logits, None

  with tf.variable_scope('logits/rnn/attention', reuse=reuse) as scope:
    # The RNN output at the last non-padded step acts as the attention query.
    final_step_output = rnn_common.select_last_activations(
        rnn_output, tf.to_int32(sequence_length))
    attended_output, attention_weights = compute_attention(
        rnn_output, final_step_output, hparams.attention_hidden_layer_dim,
        seq_mask, sequence_length)
    attention_logits = tf.layers.dense(
        attended_output, 1, name=scope, reuse=reuse, activation=None)
  return attention_logits, attention_weights
734

735

736
class ObservationSequenceModel(object):
  """Model that runs an RNN over the time series of observation values.

  Consider a single lab (e.g. heart rate) and its value (e.g. 60) at a time.
  The input to the RNN at that timestep will have a value of 60 at the unique
  position for heart rate. The positions of all other lab tests will be 0.

  Additional input to the RNN include an indicator (to be able to distinguish a
  true lab measurement of 0 from the padded ones) and a notion of time
  (in particular how many hours have passed since the previous time-step).

  Caution: This model can only be run on condensed SequenceExample with an
  observation present each time step.
  """

  def create_model_hparams(self):
    """Returns default hparams for observation sequence model.

    Returns:
      A contrib_training.HParams holding the feature configuration,
      optimizer settings, dropout keep probabilities, normalization settings
      and attribution options consumed by create_model_fn.
    """
    categorical_values_str = 'loinc:2823-3,loinc:2160-0,loinc:804-5,loinc:3094-0,loinc:786-4,loinc:2075-0,loinc:2951-2,loinc:34728-6,mimic3:observation_code:834,mimic3:observation_code:678,loinc:2345-7,mimic3:observation_code:3603,mimic3:observation_code:223761,loinc:3173-2,loinc:5895-7,loinc:5902-2,loinc:2601-3,loinc:2000-8,loinc:2777-1,mimic3:observation_code:3655,loinc:32693-4,mimic3:observation_code:679,mimic3:observation_code:676,loinc:2339-0,loinc:1994-3,mimic3:observation_code:224690,loinc:1975-2,loinc:1742-6,loinc:1920-8,loinc:6768-6,mimic3:observation_code:3312,mimic3:observation_code:8502,mimic3:observation_code:3313,loinc:1751-7,loinc:6598-7,mimic3:observation_code:225309,mimic3:observation_code:225310,mimic3:observation_code:40069,loinc:3016-3,loinc:1968-7,loinc:4548-4,loinc:2093-3,loinc:2085-9,loinc:2090-9,mimic3:observation_code:6701,mimic3:observation_code:8555,mimic3:observation_code:6702,loinc:10839-9,mimic3:observation_code:3318,mimic3:observation_code:3319'
    return contrib_training.HParams(
        context_features=['sequenceLength'],
        batch_size=128,
        learning_rate=0.002,
        sequence_features=[
            'deltaTime',
            'Observation.code',
            'Observation.valueQuantity.value',
            'Observation.valueQuantity.unit',
            'Observation.code.harmonized:valueset-observation-name',
        ],
        feature_value='Observation.valueQuantity.value',
        categorical_values=categorical_values_str.split(','),
        categorical_seq_feature='Observation.code',
        label_key='label.in_hospital_death',
        input_keep_prob=1.0,
        attribution_threshold=0.0001,
        attribution_max_delta_time=12 * 60 * 60,
        rnn_size=64,
        variational_recurrent_keep_prob=0.99,
        variational_input_keep_prob=0.97,
        variational_output_keep_prob=0.98,
        sequence_prediction=False,
        normalize=True,
        momentum=0.75,
        min_value=-1000.0,
        max_value=1000.0,
        # If sequence_prediction is True and volatility_loss_factor > 0, the
        # training loss also includes volatility_loss_factor times the sum of
        # squared (masked) changes in predictions across the sequence, as a
        # way to learn models with less volatile predictions.
        volatility_loss_factor=0.0,
        include_sequence_prediction=True,
        include_gradients_attribution=True,
        include_gradients_sum_time_attribution=False,
        include_gradients_avg_time_attribution=False,
        include_path_integrated_gradients_attribution=True,
        use_rnn_attention=False,
        attention_hidden_layer_dim=0,
        include_diff_sequence_prediction_attribution=True,
        # If include_path_integrated_gradients_attribution is True, this is the
        # number of steps between the old and the current observation value.
        path_integrated_gradients_num_steps=10,
    )

  def create_model_fn(self, hparams):
    """Returns a function to build the model.

    Args:
      hparams: The hyperparameters.

    Returns:
      A function to build the model's graph. This function is called by
      the Estimator object to construct the graph.
    """

    def model_fn(features, labels, mode):
      """Creates the prediction, loss, and train ops.

      Args:
        features: A dictionary of tensors keyed by the feature name.
        labels: A dictionary of label tensors keyed by the label key. It is
          not read in PREDICT mode (and may be None there). It is never
          modified.
        mode: The execution mode, one of tf_estimator.ModeKeys.

      Returns:
        An EstimatorSpec with the mode, predictions, loss, train_op,
        eval_metric_ops and export_outputs (a dictionary specifying the
        output for a serving request in PREDICT mode).

      Raises:
        ValueError: If an option that needs per-timestep logits
          (sequence_prediction during training, or the time-summed /
          time-averaged gradient attributions) is combined with
          use_rnn_attention.
      """
      # 1. Construct input to RNN
      sequence_feature_map = {
          k: features[input_fn.SEQUENCE_KEY_PREFIX + k]
          for k in hparams.sequence_features
      }
      sequence_length = tf.squeeze(
          features[input_fn.CONTEXT_KEY_PREFIX + 'sequenceLength'],
          axis=1,
          name='sq_seq_len')
      tf.summary.scalar('sequence_length', tf.reduce_mean(sequence_length))
      diff_delta_time, obs_values, indicator = construct_input(
          sequence_feature_map, hparams.categorical_values,
          hparams.categorical_seq_feature, hparams.feature_value, mode,
          hparams.normalize, hparams.momentum, hparams.min_value,
          hparams.max_value, hparams.input_keep_prob)

      seq_mask = tf.expand_dims(
          tf.sequence_mask(sequence_length, dtype=tf.float32), axis=2)
      logits, weights = construct_logits(
          diff_delta_time,
          obs_values,
          indicator,
          sequence_length,
          seq_mask,
          hparams,
          reuse=False)

      # Read the label once, outside of PREDICT mode where labels are unused.
      # `label_tensor` is the per-example label, used by the eval metrics.
      # `loss_labels` is what the loss is computed against; it may be tiled
      # over time below. Local names avoid mutating the caller's dict.
      label_tensor = None
      if mode != tf_estimator.ModeKeys.PREDICT:
        label_tensor = labels[hparams.label_key]
      loss_labels = label_tensor

      all_attribution_dict = {}
      if mode == tf_estimator.ModeKeys.TRAIN:
        if hparams.sequence_prediction:
          if hparams.use_rnn_attention:
            raise ValueError(
                'sequence_prediction is not supported with use_rnn_attention.')
          # If we train a sequence_prediction we repeat the labels over time.
          loss_labels = tf.tile(
              tf.expand_dims(label_tensor, 2),
              multiples=[1, tf.shape(logits)[1], 1])
          if hparams.volatility_loss_factor > 0.0:
            volatility = tf.reduce_sum(
                tf.square(seq_mask *
                          compute_prediction_diff_attribution(logits)))
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                 volatility * hparams.volatility_loss_factor)
        elif not hparams.use_rnn_attention:
          # Only the logit at the last valid timestep is trained.
          logits = rnn_common.select_last_activations(
              logits, tf.to_int32(sequence_length))
      else:
        if hparams.sequence_prediction:
          last_logits = rnn_common.select_last_activations(
              logits, tf.to_int32(sequence_length))
        else:
          last_logits = logits
        if mode == tf_estimator.ModeKeys.PREDICT:
          delta_time = sequence_feature_map['deltaTime']
          all_attributions = {}
          if hparams.include_gradients_attribution:
            all_attributions['gradient_last'] = compute_gradient_attribution(
                last_logits, obs_values, indicator)
          if hparams.include_gradients_sum_time_attribution:
            if hparams.use_rnn_attention:
              raise ValueError(
                  'include_gradients_sum_time_attribution is not supported '
                  'with use_rnn_attention.')
            all_attributions['gradient_sum'] = compute_gradient_attribution(
                _predictions_for_gradients(
                    logits, seq_mask, delta_time,
                    hparams.attribution_max_delta_time, averaged=False),
                obs_values, indicator)
          if hparams.include_gradients_avg_time_attribution:
            if hparams.use_rnn_attention:
              raise ValueError(
                  'include_gradients_avg_time_attribution is not supported '
                  'with use_rnn_attention.')
            all_attributions['gradient_avg'] = compute_gradient_attribution(
                _predictions_for_gradients(
                    logits, seq_mask, delta_time,
                    hparams.attribution_max_delta_time, averaged=True),
                obs_values, indicator)
          if hparams.include_path_integrated_gradients_attribution:
            all_attributions['integrated_gradient'] = (
                compute_path_integrated_gradient_attribution(
                    obs_values, indicator, diff_delta_time, delta_time,
                    sequence_length, seq_mask, hparams))
          if hparams.use_rnn_attention:
            all_attributions['rnn_attention'] = weights
          if hparams.include_diff_sequence_prediction_attribution:
            all_attributions['diff_sequence'] = (
                compute_prediction_diff_attribution(logits))

          for attribution_name, attribution in all_attributions.items():
            attribution_dict = convert_attribution(
                attribution,
                sequence_feature_map,
                seq_mask,
                delta_time,
                hparams.attribution_threshold,
                hparams.attribution_max_delta_time,
                prefix=attribution_name + '-')
            all_attribution_dict.update(attribution_dict)
          if hparams.include_sequence_prediction:
            # Add the predictions at each time step to the attention dictionary.
            attribution_indices = tf.where(seq_mask > 0.5)
            all_attribution_dict['predictions'] = tf.sparse.expand_dims(
                tf.SparseTensor(
                    indices=attribution_indices,
                    values=tf.gather_nd(
                        tf.sigmoid(logits), attribution_indices),
                    dense_shape=tf.to_int64(tf.shape(delta_time))),
                axis=1)
        # At test/inference time we only make a single prediction even if we did
        # sequence_prediction during training.
        logits = last_logits
        seq_mask = None

      probabilities = tf.sigmoid(logits)
      classes = probabilities > 0.5
      predictions = {
          PredictionKeys.LOGITS: logits,
          PredictionKeys.PROBABILITIES: probabilities,
          PredictionKeys.CLASSES: classes
      }
      # Calculate the loss for TRAIN and EVAL, but not PREDICT.
      if mode == tf_estimator.ModeKeys.PREDICT:
        loss = None
      else:
        loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=loss_labels,
            logits=predictions[PredictionKeys.LOGITS])
        # Per-timestep logits (and hence masking) only exist when training a
        # sequence_prediction model; in EVAL the logits were reduced to the
        # last step above and seq_mask is None.
        if (hparams.sequence_prediction and
            mode == tf_estimator.ModeKeys.TRAIN):
          loss *= seq_mask
        loss = tf.reduce_mean(loss)
        regularization_losses = tf.losses.get_regularization_losses()
        if regularization_losses:
          tf.summary.scalar('loss/prior_regularization', loss)
          regularization_loss = tf.add_n(regularization_losses)
          tf.summary.scalar('loss/regularization_loss', regularization_loss)
          loss += regularization_loss
        tf.summary.scalar('loss', loss)

      train_op = None
      if mode == tf_estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(
            learning_rate=hparams.learning_rate, beta1=0.9, beta2=0.999,
            epsilon=1e-8)
        optimizer = contrib_estimator.clip_gradients_by_norm(optimizer, 6.0)
        train_op = contrib_training.create_train_op(
            total_loss=loss, optimizer=optimizer, summarize_gradients=False)
      if mode != tf_estimator.ModeKeys.TRAIN:
        # Flatten each SparseTensor attribution into three dense prediction
        # entries so that they can be exported.
        for k, v in all_attribution_dict.items():
          if not isinstance(v, tf.SparseTensor):
            raise ValueError('Expect attributions to be in SparseTensor, '
                             'getting %s for feature %s' %
                             (v.__class__.__name__, k))
          predictions['attention_attribution,%s,indices' % k] = v.indices
          predictions['attention_attribution,%s,values' % k] = v.values
          predictions['attention_attribution,%s,shape' % k] = v.dense_shape

      eval_metric_ops = {}
      if mode == tf_estimator.ModeKeys.EVAL:
        auc = tf.metrics.auc
        prob_k = PredictionKeys.PROBABILITIES
        class_k = PredictionKeys.CLASSES
        m = 'careful_interpolation'
        metric_fn_dict = {
            'auc-roc':
                lambda l, p: auc(l, p[prob_k], curve='ROC', summation_method=m),
            'auc-pr':
                lambda l, p: auc(l, p[prob_k], curve='PR', summation_method=m),
            'accuracy':
                lambda l, p: tf.metrics.accuracy(l, p[class_k]),
        }
        # Metrics compare the (untiled) per-example labels with the single
        # final prediction per example.
        for (k, f) in metric_fn_dict.items():
          eval_metric_ops[k] = f(label_tensor, predictions)
      # Define the output for serving.
      export_outputs = {}
      if mode == tf_estimator.ModeKeys.PREDICT:
        export_outputs = {
            'mortality': tf_estimator.export.PredictOutput(predictions)
        }

      return tf_estimator.EstimatorSpec(
          mode=mode,
          predictions=predictions,
          loss=loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs)

    return model_fn
1005

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.