google-research
1004 lines · 45.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Observation based RNN model."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import tensorflow.compat.v1 as tf
23from tensorflow.compat.v1 import estimator as tf_estimator
24
25from explaining_risk_increase import input_fn
26from tensorflow.contrib import estimator as contrib_estimator
27from tensorflow.contrib import lookup as contrib_lookup
28from tensorflow.contrib import rnn as contrib_rnn
29from tensorflow.contrib import training as contrib_training
30from tensorflow.contrib.learn.python.learn.estimators import rnn_common
31
32
33TOLERANCE = 0.2
34
35
class PredictionKeys(object):
  """Names of the entries in the prediction dictionary emitted by the model.

  Mirrors the role of an enum: LOGITS holds raw scores, PROBABILITIES the
  sigmoid-transformed scores, and CLASSES the thresholded class decisions.
  """
  LOGITS = 'logits'
  PROBABILITIES = 'probs'
  CLASSES = 'classes'
41
42
def _most_recent_obs_value(obs_values, indicator, delta_time,
                           attribution_max_delta_time):
  """Returns the most recent lab result for each test within a time frame.

  The eligible lab values fall into a time window until time of prediction -
  attribution_max_delta_time. Among those we select their most recent value
  or zero if there are none.

  Args:
    obs_values: A dense representation of the observation_values at the
      position of their obs_code_ids. A padded Tensor of shape [batch_size,
      max_sequence_length, vocab_size] of type float32 where obs_values[b, t,
      id] = observation_values[b, t, 0] and id = observation_code_ids[b, t, 0]
      and obs_values[b, t, x] = 0 for all other x != id. If t is greater than
      the sequence_length of batch entry b then the result is 0 as well.
    indicator: A one-hot encoding of whether a value in obs_values comes from
      observation_values or is just filled in to be 0. A Tensor of shape
      [batch_size, max_sequence_length, vocab_size] and type float32.
    delta_time: A Tensor of shape [batch_size, max_sequence_length] describing
      the time to prediction.
    attribution_max_delta_time: Time threshold so that we return the most
      recent lab values among those that are at least
      attribution_max_delta_time seconds old at time of prediction.

  Returns:
    A Tensor of shape [batch_size, 1, vocab_size] of the most recent lab
    results for all lab tests that are at least attribution_max_delta_time old
    at time of prediction.
  """
  batch_size = tf.shape(indicator)[0]
  seq_len = tf.shape(indicator)[1]
  num_obs = indicator.shape[2]
  # Prepend a dummy so that for lab tests for which we have no eligible lab
  # values we will select 0.
  obs_values = tf.concat(
      [tf.zeros([batch_size, 1, num_obs]), obs_values], axis=1)
  indicator = tf.concat([tf.ones([batch_size, 1, num_obs]), indicator], axis=1)
  delta_time = tf.to_int32(delta_time)
  # The dummy timestep is given delta_time == attribution_max_delta_time so
  # that it always passes the eligibility filter below.
  delta_time = tf.concat(
      [
          tf.zeros([batch_size, 1, 1], dtype=tf.int32) +
          attribution_max_delta_time, delta_time
      ],
      axis=1)
  # First we figure out what the eligible lab values are that are at least
  # attribution_max_delta_time old.
  indicator = tf.to_int32(indicator)
  indicator *= tf.to_int32(delta_time >= attribution_max_delta_time)
  range_val = tf.expand_dims(tf.range(seq_len + 1), axis=0)
  range_val = tf.tile(range_val, multiples=[tf.shape(indicator)[0], 1])
  # [[[0], [1], ..., [max_sequence_length]],
  #  [[0], [1], ..., [max_sequence_length]],
  #  ...]
  range_val = tf.expand_dims(range_val, axis=2)
  # [batch_size, max_sequence_length, vocab_size] with 1 non-zero number per
  # time-step equal to that time-step.
  seq_indicator = indicator * range_val
  # [batch_size, vocab_size] with the time-step of the last lab value.
  last_val_indicator = tf.reduce_max(seq_indicator, axis=1, keepdims=True)
  last_val_indicator = tf.tile(
      last_val_indicator, multiples=[1, tf.shape(indicator)[1], 1])

  # eq indicates which lab values are the most recent ones.
  eq = tf.logical_and(
      tf.equal(last_val_indicator, seq_indicator), indicator > 0)
  most_recent_obs_value_indicator = tf.where(eq)
  # Collect the lab values associated with those indices.
  res = tf.gather_nd(obs_values, most_recent_obs_value_indicator)
  # Reorder the values by batch and then by lab test. Transposing to
  # [batch, vocab, time] before the reorder makes the values come out sorted
  # per (batch, lab test) so the final reshape lines up one value per test.
  res_sorted = tf.sparse_reorder(
      tf.sparse_transpose(
          tf.SparseTensor(
              indices=most_recent_obs_value_indicator,
              values=res,
              dense_shape=tf.to_int64(
                  tf.stack([batch_size, seq_len + 1, num_obs]))),
          perm=[0, 2, 1])).values

  return tf.reshape(res_sorted, [batch_size, 1, num_obs])
122
123
def _predictions_for_gradients(predictions, seq_mask, delta_time,
                               attribution_max_delta_time, averaged):
  """Aggregates eligible predictions over time.

  Predictions are eligible if they are within the sequence_length (as
  indicated by seq_mask) and their associated delta_time is at most
  attribution_max_delta_time.
  Eligible predictions are either averaged across those eligible times (if
  averaged=True) or summed otherwise.

  Args:
    predictions: A Tensor of shape [batch_size, max_seq_len, 1] with the
      predictions in the sequence.
    seq_mask: A Tensor of shape [batch_size, max_sequence_length, 1]
      indicating which timesteps are padded.
    delta_time: A Tensor of shape [batch_size, max_sequence_length] describing
      the time to prediction.
    attribution_max_delta_time: Attribution is limited to values that are no
      older than that many seconds at time of prediction.
    averaged: Whether predictions are simply summed up across the time-steps
      or averaged over on the sequence length.
  Returns:
    A Tensor of shape [batch, 1, 1] of the eligible predictions aggregated
    across time.
  """
  # A timestep is eligible when it is unpadded AND recent enough.
  eligibility = seq_mask * tf.to_float(delta_time < attribution_max_delta_time)
  masked_predictions = predictions * eligibility
  if averaged:
    # Normalize by the number of eligible timesteps per batch entry.
    num_eligible = tf.reduce_sum(eligibility, axis=1, keepdims=True)
    masked_predictions = masked_predictions / num_eligible
  return tf.reduce_sum(masked_predictions, axis=1, keepdims=True)
154
155
def compute_gradient_attribution(predictions, obs_values, indicator):
  """Constructs the attribution of what inputs result in a higher prediction.

  Attribution here refers to the timesteps in which the predictions (derived
  from the logits) increased. We are only interested in increases in the
  previous 12h.

  Args:
    predictions: A Tensor of shape [batch_size, 1, 1] with the predictions in
      the sequence.
    obs_values: A dense representation of the observation_values with
      obs_values[b, t, :] has at most one non-zero value at the position of
      the corresponding lab test from obs_code_ids with the value of the lab
      result. A padded Tensor of shape [batch_size, max_sequence_length,
      vocab_size] of type float32 of possibly normalized observation values.
    indicator: A one-hot encoding of whether a value in obs_values comes from
      observation_values or is just filled in to be 0. A Tensor of shape
      [batch_size, max_sequence_length, vocab_size] and type float32.
  Returns:
    A Tensor of shape [batch, max_sequence_length, 1] of the gradient of the
    prediction as a function of the lab result at that batch-entry time.
  """
  squeezed_predictions = tf.squeeze(
      predictions, axis=1, name='squeeze_pred_for_gradients')
  gradient = tf.gradients(squeezed_predictions, [obs_values])[0]
  # Zero-out gradients for other lab-tests and then sum up across lab tests
  # for which at most one gradient will be non-zero.
  gradient = gradient * indicator
  return tf.reduce_sum(gradient, axis=2, keepdims=True)
186
187
def compute_path_integrated_gradient_attribution(
    obs_values,
    indicator,
    diff_delta_time,
    delta_time,
    sequence_length,
    seq_mask,
    hparams,
    construct_logits_fn=None):
  """Constructs the attribution of what inputs result in a higher prediction.

  Attribution here refers to the integrated gradients as defined here
  https://arxiv.org/pdf/1703.01365.pdf and approximated for the j-th variable
  via

  (x-x') * 1/num_steps * sum_{i=1}^{num_steps} of the derivative of
  F(x'+(x-x')*i/num_steps) w.r.t. its j-th input.

  where we take x' the most recent value before attribution_max_delta_time and
  x to be the subsequent observation values from the same lab test.
  x'+(x-x')*i/num_steps is the linear interpolation between x' and x.

  Args:
    obs_values: A dense representation of the observation_values with
      obs_values[b, t, :] has at most one non-zero value at the position of
      the corresponding lab test from obs_code_ids with the value of the lab
      result. A padded Tensor of shape [batch_size, max_sequence_length,
      vocab_size] of type float32 of possibly normalized observation values.
    indicator: A one-hot encoding of whether a value in obs_values comes from
      observation_values or is just filled in to be 0. A Tensor of shape
      [batch_size, max_sequence_length, vocab_size] and type float32.
    diff_delta_time: Difference between two consecutive time steps.
    delta_time: A Tensor of shape [batch_size, max_sequence_length] describing
      the time to prediction.
    sequence_length: Sequence length (before padding), Tensor of shape
      [batch_size].
    seq_mask: A Tensor of shape [batch_size, max_sequence_length, 1]
      indicating which timesteps are padded.
    hparams: Hyper parameters.
    construct_logits_fn: A method with constructing the logits given input as
      construct_logits. If None using construct_logits.
  Returns:
    A Tensor of shape [batch, max_sequence_length, 1] of the gradient of the
    prediction as a function of the lab result at that batch-entry time.
  """
  # Resolve the default once, outside the interpolation loop. (Previously
  # this was re-checked on every iteration even though it is loop-invariant.)
  if construct_logits_fn is None:
    construct_logits_fn = construct_logits

  # x': the most recent observation value that is at least
  # attribution_max_delta_time seconds old at time of prediction.
  last_obs_values_0 = _most_recent_obs_value(obs_values, indicator, delta_time,
                                             hparams.attribution_max_delta_time)
  gradients = []
  # We need to limit the diff over the base to timesteps after base.
  last_obs_values = last_obs_values_0 * (
      tf.to_float(indicator) *
      tf.to_float(delta_time < hparams.attribution_max_delta_time))
  obs_values_with_last_replaced = obs_values * tf.to_float(
      delta_time >= hparams.attribution_max_delta_time) + last_obs_values
  diff_over_base = obs_values - obs_values_with_last_replaced

  for i in range(hparams.path_integrated_gradients_num_steps):
    # alpha walks from 0 to 1 (inclusive) along the straight-line path
    # between the baseline x' and the actual input x.
    alpha = 1.0 * i / (hparams.path_integrated_gradients_num_steps - 1)
    step_obs_values = obs_values_with_last_replaced + diff_over_base * alpha
    logits, _ = construct_logits_fn(
        diff_delta_time,
        step_obs_values,
        indicator,
        sequence_length,
        seq_mask,
        hparams,
        reuse=True)
    if hparams.use_rnn_attention:
      last_logits = logits
    else:
      last_logits = rnn_common.select_last_activations(
          logits, tf.to_int32(sequence_length))
    # Ideally, we'd like to get the gradients of the change in
    # value over the previous one to attribute it to both and not just a
    # single value.
    gradient = compute_gradient_attribution(last_logits, step_obs_values,
                                            indicator)
    gradients.append(
        tf.reduce_sum(diff_over_base, axis=2, keepdims=True) * gradient)
  # Average the (x - x') * dF/dx_j terms over all interpolation steps.
  return tf.add_n(gradients) / tf.to_float(
      hparams.path_integrated_gradients_num_steps)
271
272
def compute_attention(seq_output, last_output, hidden_layer_dim, seq_mask,
                      sequence_length):
  """Constructs attention of the last_output as query and the sequence output.

  The attention is the dot-product of the last_output (the final RNN output),
  with the seq_output (the RNN's output at each step). Here the final RNN
  output is considered as the "query" or "context" vector. The final attention
  output is a weighted sum of the RNN's outputs at all steps. Details:

  alpha_i = seq_output_i * last_output
  beta is then obtained by normalizing alpha:
  beta_i = exp(alpha_i) / sum_j exp(alpha_j)
  The new attention vector is then the beta-weighted sum over the seq_output:
  attention_vector = sum_i beta_i * seq_output_i

  If hidden_dim > 0 then before computing alpha the seq_output and the
  last_output are sent through two separate hidden layers.
  seq_output = hidden_layer(seq_output)
  last_output = hidden_layer(last_output)

  Args:
    seq_output: The raw rnn output of shape [batch_size, max_sequence_length,
      rnn_size].
    last_output: The last output of the rnn of shape [batch_size, rnn_size].
    hidden_layer_dim: If 0 no hidden layer is applied before multiplying the
      last_logits with the seq_logits.
    seq_mask: A Tensor of shape [batch_size, max_sequence_length, 1]
      indicating which timesteps are padded.
    sequence_length: Sequence length (before padding), Tensor of shape
      [batch_size].

  Returns:
    Attention output with shape [batch_size, rnn_size].
    The attention beta tensor.
  """
  # Compute the weights.
  if hidden_layer_dim > 0:
    last_output = tf.layers.dense(
        last_output, hidden_layer_dim, activation=tf.nn.relu6)
    seq_output = tf.layers.dense(
        seq_output, hidden_layer_dim, activation=tf.nn.relu6)
  last_output = tf.expand_dims(last_output, 1)  # [batch_size, 1, rnn_size]
  tmp = tf.multiply(seq_output, last_output)  # dim 1: broadcast
  alpha_tensor = tf.reduce_sum(tmp, 2)  # [b, max_seq_len]
  # NOTE(review): the mask multiplies the scores rather than setting padded
  # positions to -inf, so padded steps still receive exp(0) weight in the
  # softmax below — presumably accepted here; confirm if padding should get
  # zero attention.
  alpha_tensor *= tf.squeeze(seq_mask, axis=2)
  beta_tensor = tf.nn.softmax(alpha_tensor)  # using default dim -1
  beta_tensor = tf.expand_dims(beta_tensor, -1)  # [b, max_seq_len, 1]

  # Compute weighted sum of the original rnn_outputs over all steps.
  # NOTE(review): when hidden_layer_dim > 0 this sums the hidden-layer
  # transformed seq_output, not the raw RNN outputs — the returned vector
  # then has hidden_layer_dim channels; verify against callers.
  tmp = seq_output * beta_tensor  # last dim: use "broadcast"
  rnn_outputs_weighted_sum = tf.reduce_sum(tmp, 1)  # [b, rnn_size]
  last_beta = rnn_common.select_last_activations(
      beta_tensor, tf.to_int32(sequence_length))
  tf.summary.histogram('last_beta_attention', last_beta)

  return rnn_outputs_weighted_sum, beta_tensor
329
330
def compute_prediction_diff_attribution(logits):
  """Constructs the attribution of what inputs result in a higher prediction.

  Attribution here refers to the timesteps in which the predictions (derived
  from the logits) increased.

  Args:
    logits: The logits of the model_fn.
  Returns:
    A Tensor of shape [batch_size, max_sequence_length, 1] with an attribution
    value at time t of prediction at time t minus prediction at time t-1.
  """
  probabilities = tf.sigmoid(logits)
  dims = tf.shape(logits)
  # A single zero timestep prepended so that the first prediction is compared
  # against a baseline of 0.
  leading_zeros = tf.zeros(shape=[dims[0], 1, dims[2]], dtype=tf.float32)
  # Our basic notion of attribution at timestep i is how much the predicted
  # risk increased at that time compared to the previous prediction.
  shifted = tf.concat(
      [leading_zeros, probabilities[:, :-1, :]], axis=1, name='attribution')
  return probabilities - shifted
350
351
def convert_attribution(attribution, sequence_feature_map, seq_mask, delta_time,
                        attribution_threshold, attribution_max_delta_time,
                        prefix=''):
  """Constructs the attribution of what inputs result in a higher prediction.

  Attribution here refers to the timesteps in which the predictions (derived
  from the logits) increased. We are only interested in increases in the
  previous attribution_max_delta_time.

  Args:
    attribution: A Tensor of shape [batch, max_sequence_length, 1] computed
      using some attribution method.
    sequence_feature_map: A dictionary from name to (Sparse)Tensor.
    seq_mask: A Tensor of shape [batch_size, max_sequence_length, 1]
      indicating which timesteps are padded.
    delta_time: A Tensor of shape [batch_size, max_sequence_length] describing
      the time to prediction.
    attribution_threshold: Attribution values below this threshold will be
      dropped.
    attribution_max_delta_time: Attribution is limited to values that are no
      older than that many seconds at time of prediction.
    prefix: A string to prepend to the feature names for the attribution_dict.
  Returns:
    A dictionary from feature names to SparseTensors of
    dense_shape [batch_size, max_sequence_length, 1].
  """
  # We do not want attribution in the padding.
  attribution *= seq_mask

  # We focus on attribution in the past 12h.
  # [batch_size, max_sequence_length, 1]
  attribution *= tf.to_float(delta_time < attribution_max_delta_time)

  # We get rid of low attribution.
  attribution_indices = tf.where(attribution > attribution_threshold)
  attribution_values = tf.gather_nd(attribution, attribution_indices)

  # Now, attribution.indices indicate in the input timesteps which we should
  # attend to.
  attribution_dict = {}
  for feature, sp_feature in sequence_feature_map.items():
    # Limitation: This is not going to work for sequence feature in which
    # the third (last/token) dimension is > 1. In that case only the first
    # token would be highlighted.
    # NOTE(review): tf.shape(sp_feature) assumes sp_feature is dense or
    # convertible; for SparseTensor entries the dense_shape may be the
    # intended source — confirm against callers.
    attribution_dict[prefix + feature] = tf.sparse.expand_dims(
        tf.SparseTensor(
            indices=attribution_indices,
            values=attribution_values,
            dense_shape=tf.to_int64(tf.shape(sp_feature))), axis=1)
  return attribution_dict
402
403
def normalize_each_feature(observation_values, obs_code, vocab_size, mode,
                           momentum):
  """Combines SparseTensors of observation codes and values into a Tensor.

  Args:
    observation_values: A SparseTensor of type float with the observation
      values of dense shape [batch_size, max_sequence_length, 1].
      There may be no time gaps in between codes.
    obs_code: A Tensor of shape [?, 3] of type int32 with the ids that go
      along with the observation_values. We will do the normalization
      separately for each lab test.
    vocab_size: The range of the values in obs_code is from 0 to vocab_size.
    mode: The execution mode, as defined in tf.estimator.ModeKeys.
    momentum: Mean and variance will be updated as
      momentum*old_value + (1-momentum) * new_value.
  Returns:
    observation_values as in the input only with normalized values.
  """
  with tf.variable_scope('batch_normalization'):
    new_indices = []
    new_values = []

    # One independent batch-norm (with its own moving statistics) per lab
    # test id, so every feature is normalized on its own scale.
    for i in range(vocab_size):
      with tf.variable_scope('bn' + str(i)):
        positions_of_feature_i = tf.where(tf.equal(obs_code, i))
        values_of_feature_i = tf.gather_nd(observation_values.values,
                                           positions_of_feature_i)
        if mode == tf_estimator.ModeKeys.TRAIN:
          tf.summary.scalar('avg_observation_values/' + str(i),
                            tf.reduce_mean(values_of_feature_i))
          tf.summary.histogram('observation_values/' + str(i),
                               values_of_feature_i)
        batchnorm_layer = tf.layers.BatchNormalization(
            axis=1,
            momentum=momentum,
            epsilon=0.01,
            trainable=True)
        # The 1-d value vector is expanded to rank 2 because
        # BatchNormalization normalizes along axis=1, then squeezed back.
        normalized_values = tf.squeeze(
            batchnorm_layer.apply(
                tf.expand_dims(values_of_feature_i, axis=1),
                training=(mode == tf_estimator.ModeKeys.TRAIN)
            ),
            axis=1,
            name='squeeze_normalized_values')
        if mode == tf_estimator.ModeKeys.TRAIN:
          tf.summary.scalar('batchnorm_layer/moving_mean/' + str(i),
                            tf.squeeze(batchnorm_layer.moving_mean))
          tf.summary.scalar('batchnorm_layer/moving_variance/' + str(i),
                            tf.squeeze(batchnorm_layer.moving_variance))
          tf.summary.scalar('avg_normalized_values/' + str(i),
                            tf.reduce_mean(normalized_values))
          tf.summary.histogram('normalized_observation_values/' + str(i),
                               normalized_values)
        indices_i = tf.gather_nd(observation_values.indices,
                                 positions_of_feature_i)
        new_indices += [indices_i]
        # Replace NaNs (e.g. when a feature has no values in the batch) with
        # zeros so downstream dense ops stay finite.
        normalized_values = tf.where(tf.is_nan(normalized_values),
                                     tf.zeros_like(normalized_values),
                                     normalized_values)
        new_values += [normalized_values]

    # Stitch the per-feature results back together; reordering restores the
    # canonical row-major index order required by SparseTensor consumers.
    normalized_sp_tensor = tf.SparseTensor(
        indices=tf.concat(new_indices, axis=0),
        values=tf.concat(new_values, axis=0),
        dense_shape=observation_values.dense_shape)
    normalized_sp_tensor = tf.sparse_reorder(normalized_sp_tensor)
    return normalized_sp_tensor
471
472
def combine_observation_code_and_values(observation_code_ids,
                                        observation_values, vocab_size, mode,
                                        normalize, momentum, min_value,
                                        max_value):
  """Combines SparseTensors of observation codes and values into a Tensor.

  Args:
    observation_code_ids: A SparseTensor of type int32 with the ids of the
      observation codes of dense shape [batch_size, max_sequence_length, 1].
      There may be no time gaps in between codes.
    observation_values: A SparseTensor of type float with the observation
      values of dense shape [batch_size, max_sequence_length, 1].
      There may be no time gaps in between codes.
    vocab_size: The range of the values in obs_code_ids is from 0 to
      vocab_size.
    mode: The execution mode, as defined in tf.estimator.ModeKeys.
    normalize: Whether to normalize each lab test.
    momentum: For the batch normalization mean and variance will be updated
      as momentum*old_value + (1-momentum) * new_value.
    min_value: Observation values smaller than this will be capped to
      min_value.
    max_value: Observation values larger than this will be capped to
      max_value.

  Returns:
    - obs_values: A dense representation of the observation_values at the
        position of their obs_code_ids. A padded Tensor of shape
        [batch_size, max_sequence_length, vocab_size] of type float32
        where obs_values[b, t, id] = observation_values[b, t, 0] and
        id = observation_code_ids[b, t, 0] and obs_values[b, t, x] = 0
        for all other x != id. If t is greater than the
        sequence_length of batch entry b then the result is 0 as well.
    - indicator: A one-hot encoding of whether a value in obs_values comes
        from observation_values or is just filled in to be 0. A Tensor of
        shape [batch_size, max_sequence_length, vocab_size] and type
        float32.
  """
  obs_code = observation_code_ids.values
  if normalize:
    with tf.variable_scope('values'):
      observation_values = normalize_each_feature(
          observation_values, obs_code, vocab_size, mode, momentum)
  # Drop the trailing (token) dimension: each timestep holds one value.
  observation_values_rank2 = tf.SparseTensor(
      values=observation_values.values,
      indices=observation_values.indices[:, 0:2],
      dense_shape=observation_values.dense_shape[0:2])
  # Append the lab-test id as the third index so the value lands at
  # [batch, time, code_id] in the dense output.
  obs_indices = tf.concat(
      [observation_values_rank2.indices,
       tf.expand_dims(obs_code, axis=1)],
      axis=1, name='obs_indices')
  obs_shape = tf.concat(
      [observation_values_rank2.dense_shape, [vocab_size]], axis=0,
      name='obs_shape')

  obs_values = tf.sparse_to_dense(obs_indices, obs_shape,
                                  observation_values_rank2.values)
  obs_values.set_shape([None, None, vocab_size])
  indicator = tf.sparse_to_dense(obs_indices, obs_shape,
                                 tf.ones_like(observation_values_rank2.values))
  indicator.set_shape([None, None, vocab_size])
  # Clip values into [min_value, max_value].
  obs_values = tf.minimum(obs_values, max_value)
  obs_values = tf.maximum(obs_values, min_value)
  return obs_values, indicator
534
535
def construct_input(sequence_feature_map, categorical_values,
                    categorical_seq_feature, feature_value, mode, normalize,
                    momentum, min_value, max_value, input_keep_prob):
  """Returns a function to build the model.

  Args:
    sequence_feature_map: A dictionary of (Sparse)Tensors of dense shape
      [batch_size, max_sequence_length, None] keyed by the feature name.
    categorical_values: Potential values of the categorical_seq_feature.
    categorical_seq_feature: Name of feature of observation code.
    feature_value: Name of feature of observation value.
    mode: The execution mode, as defined in tf.estimator.ModeKeys.
    normalize: Whether to normalize each lab test.
    momentum: For the batch normalization mean and variance will be updated
      as momentum*old_value + (1-momentum) * new_value.
    min_value: Observation values smaller than this will be capped to
      min_value.
    max_value: Observation values larger than this will be capped to
      max_value.
    input_keep_prob: Keep probability for input observation values.

  Returns:
    - diff_delta_time: Tensor of shape [batch_size, max_seq_length, 1]
        with the time difference (in hours) between consecutive steps.
    - obs_values: A dense representation of the observation_values with
        obs_values[b, t, :] has at most one non-zero value at the
        position of the corresponding lab test from obs_code_ids with
        the value of the lab result. A padded Tensor of shape
        [batch_size, max_sequence_length, vocab_size] of type float32
        of possibly normalized observation values.
    - indicator: A one-hot encoding of whether a value in obs_values comes
        from observation_values or is just filled in to be 0. A Tensor of
        shape [batch_size, max_sequence_length, vocab_size] and type
        float32.
  """
  with tf.variable_scope('input'):
    sequence_feature_map = {
        k: tf.sparse_reorder(s) if isinstance(s, tf.SparseTensor) else s
        for k, s in sequence_feature_map.items()
    }
    # Filter out invalid values.
    # For invalid observation values we do this through a sparse retain.
    # This makes sure that the invalid values will not be considered in the
    # normalization.
    observation_values = sequence_feature_map[feature_value]
    observation_code_sparse = sequence_feature_map[categorical_seq_feature]
    # Future work: Create a flag for the missing value indicator.
    # 9999999.0 is the sentinel that marks a missing observation value.
    valid_values = tf.abs(observation_values.values - 9999999.0) > TOLERANCE
    # Apply input dropout by folding a Bernoulli mask into valid_values, so
    # dropped values are removed from both the values and the codes.
    if input_keep_prob < 1.0:
      random_tensor = input_keep_prob
      random_tensor += tf.random_uniform(tf.shape(observation_values.values))
      # 0. if [input_keep_prob, 1.0) and 1. if [1.0, 1.0 + input_keep_prob)
      dropout_mask = tf.floor(random_tensor)
      if mode == tf_estimator.ModeKeys.TRAIN:
        valid_values = tf.to_float(valid_values) * dropout_mask
        valid_values = valid_values > 0.5
    sequence_feature_map[feature_value] = tf.sparse_retain(
        observation_values, valid_values)
    sequence_feature_map[categorical_seq_feature] = tf.sparse_retain(
        observation_code_sparse, valid_values)

    # 1. Construct the sequence of observation values to feed into the RNN
    #    and their indicator.
    # We assign each observation code an id from 0 to vocab_size-1. At each
    # timestep we will lookup the id for the observation code and take the
    # value of the lab test and a construct a vector with all zeros but the
    # id-th position is set to the lab test value.
    obs_code = sequence_feature_map[categorical_seq_feature]
    obs_code_dense_ids = contrib_lookup.index_table_from_tensor(
        tuple(categorical_values), num_oov_buckets=0,
        name='vocab_lookup').lookup(obs_code.values)
    obs_code_sparse = tf.SparseTensor(
        values=obs_code_dense_ids,
        indices=obs_code.indices,
        dense_shape=obs_code.dense_shape)
    obs_code_sparse = tf.sparse_reorder(obs_code_sparse)
    observation_values = sequence_feature_map[feature_value]
    observation_values = tf.sparse_reorder(observation_values)
    vocab_size = len(categorical_values)
    obs_values, indicator = combine_observation_code_and_values(
        obs_code_sparse, observation_values, vocab_size, mode, normalize,
        momentum, min_value, max_value)

    # 2. We compute the diff_delta_time as additional sequence feature.
    # Note, the LSTM is very sensitive to how you encode time.
    delta_time = sequence_feature_map['deltaTime']
    diff_delta_time = tf.concat(
        [delta_time[:, :1, :], delta_time[:, :-1, :]], axis=1) - delta_time
    # Convert from seconds to hours.
    diff_delta_time = tf.to_float(diff_delta_time) / (60.0 * 60.0)

  return (diff_delta_time, obs_values, indicator)
626
627
def construct_rnn_logits(diff_delta_time,
                         obs_values,
                         indicator,
                         sequence_length,
                         rnn_size,
                         variational_recurrent_keep_prob,
                         variational_input_keep_prob,
                         variational_output_keep_prob,
                         reuse=False):
  """Computes logits combining inputs and applying an RNN.

  Args:
    diff_delta_time: Difference between two consecutive time steps.
    obs_values: A dense representation of the observation_values with
      obs_values[b, t, :] has at most one non-zero value at the position of
      the corresponding lab test from obs_code_ids with the value of the lab
      result. A padded Tensor of shape [batch_size, max_sequence_length,
      vocab_size] of type float32 of possibly normalized observation values.
    indicator: A one-hot encoding of whether a value in obs_values comes from
      observation_values or is just filled in to be 0. A Tensor of shape
      [batch_size, max_sequence_length, vocab_size] and type float32.
    sequence_length: Sequence length (before padding), Tensor of shape
      [batch_size].
    rnn_size: Size of the LSTM hidden state and output.
    variational_recurrent_keep_prob: 1 - dropout for the hidden LSTM state.
    variational_input_keep_prob: 1 - dropout for the input to the LSTM.
    variational_output_keep_prob: 1 - dropout for the output of the LSTM.
    reuse: Whether to reuse existing variables or setup new ones.

  Returns:
    logits a Tensor of shape [batch_size, max_sequence_length, 1].
  """
  with tf.variable_scope('logits/rnn', reuse=reuse) as sc:
    rnn_inputs = [diff_delta_time, indicator, obs_values]
    sequence_data = tf.concat(rnn_inputs, axis=2, name='rnn_input')

    # Run a recurrent neural network across the time dimension.
    cell = contrib_rnn.LSTMCell(rnn_size, state_is_tuple=True)
    # Variational (same-mask-per-timestep) dropout; only wrap when any of the
    # keep probabilities is actually below 1.
    if (variational_recurrent_keep_prob < 1 or variational_input_keep_prob < 1
        or variational_output_keep_prob < 1):
      cell = tf.nn.rnn_cell.DropoutWrapper(
          cell, input_keep_prob=variational_input_keep_prob,
          output_keep_prob=variational_output_keep_prob,
          state_keep_prob=variational_recurrent_keep_prob,
          variational_recurrent=True, input_size=tf.shape(sequence_data)[2],
          seed=12345678)

    output, _ = tf.nn.dynamic_rnn(
        cell,
        sequence_data,
        sequence_length=sequence_length,
        dtype=tf.float32,
        swap_memory=True,
        scope='rnn')

    # 3. Make a time-series of logits via a linear-mapping of the rnn-output
    #    to logits_dimension = 1.
    return tf.layers.dense(
        output, 1, name=sc, reuse=reuse, activation=None), output
687
688
def construct_logits(diff_delta_time, obs_values, indicator, sequence_length,
                     seq_mask, hparams, reuse):
  """Constructs logits through an RNN.

  Args:
    diff_delta_time: Difference between two consecutive time steps.
    obs_values: A dense representation of the observation_values with
      obs_values[b, t, :] has at most one non-zero value at the position of
      the corresponding lab test from obs_code_ids with the value of the lab
      result. A padded Tensor of shape [batch_size, max_sequence_length,
      vocab_size] of type float32 of possibly normalized observation values.
    indicator: A one-hot encoding of whether a value in obs_values comes from
      observation_values or is just filled in to be 0. A Tensor of shape
      [batch_size, max_sequence_length, vocab_size] and type float32.
    sequence_length: Sequence length (before padding), Tensor of shape
      [batch_size].
    seq_mask: A Tensor of shape [batch_size, max_sequence_length, 1]
      indicating which timesteps are padded.
    hparams: Hyper parameters.
    reuse: Boolean indicator of whether to re-use the variables.

  Returns:
    - Logits: A Tensor of shape [batch, {max_sequence_length,1}, 1].
    - Weights: Defaults to None. Only populated to a Tensor of shape
        [batch, max_sequence_length, 1] if hparams.use_rnn_attention is True.
  """

  logits, raw_output = construct_rnn_logits(
      diff_delta_time, obs_values, indicator, sequence_length,
      hparams.rnn_size,
      hparams.variational_recurrent_keep_prob,
      hparams.variational_input_keep_prob,
      hparams.variational_output_keep_prob,
      reuse)
  if hparams.use_rnn_attention:
    with tf.variable_scope('logits/rnn/attention', reuse=reuse) as sc:
      # Use the last unpadded RNN output as the attention query.
      last_logits = rnn_common.select_last_activations(
          raw_output, tf.to_int32(sequence_length))
      weighted_final_output, weight = compute_attention(
          raw_output, last_logits, hparams.attention_hidden_layer_dim,
          seq_mask, sequence_length)
      # Collapse the attended output to a single logit per batch entry.
      return tf.layers.dense(
          weighted_final_output, 1, name=sc, reuse=reuse,
          activation=None), weight
  else:
    return logits, None
734
735
class ObservationSequenceModel(object):
  """Model that runs an RNN over the time series of observation values.

  Consider a single lab (e.g. heart rate) and its value (e.g. 60) at a time.
  The input to the RNN at that timestep will have a value of 60 at the unique
  position for heart rate. The positions of all other lab tests will be 0.

  Additional input to the RNN include an indicator (to be able to distinguish a
  true lab measurement of 0 from the padded ones) and a notion of time
  (in particular how many hours have passed since the previous time-step).

  Caution: This model can only be run on condensed SequenceExample with an
  observation present each time step.
  """

  def create_model_hparams(self):
    """Returns default hparams for observation sequence model."""
    categorical_values_str = 'loinc:2823-3,loinc:2160-0,loinc:804-5,loinc:3094-0,loinc:786-4,loinc:2075-0,loinc:2951-2,loinc:34728-6,mimic3:observation_code:834,mimic3:observation_code:678,loinc:2345-7,mimic3:observation_code:3603,mimic3:observation_code:223761,loinc:3173-2,loinc:5895-7,loinc:5902-2,loinc:2601-3,loinc:2000-8,loinc:2777-1,mimic3:observation_code:3655,loinc:32693-4,mimic3:observation_code:679,mimic3:observation_code:676,loinc:2339-0,loinc:1994-3,mimic3:observation_code:224690,loinc:1975-2,loinc:1742-6,loinc:1920-8,loinc:6768-6,mimic3:observation_code:3312,mimic3:observation_code:8502,mimic3:observation_code:3313,loinc:1751-7,loinc:6598-7,mimic3:observation_code:225309,mimic3:observation_code:225310,mimic3:observation_code:40069,loinc:3016-3,loinc:1968-7,loinc:4548-4,loinc:2093-3,loinc:2085-9,loinc:2090-9,mimic3:observation_code:6701,mimic3:observation_code:8555,mimic3:observation_code:6702,loinc:10839-9,mimic3:observation_code:3318,mimic3:observation_code:3319'
    return contrib_training.HParams(
        context_features=['sequenceLength'],
        batch_size=128,
        learning_rate=0.002,
        sequence_features=[
            'deltaTime',
            'Observation.code',
            'Observation.valueQuantity.value',
            'Observation.valueQuantity.unit',
            'Observation.code.harmonized:valueset-observation-name',
        ],
        feature_value='Observation.valueQuantity.value',
        categorical_values=categorical_values_str.split(','),
        categorical_seq_feature='Observation.code',
        label_key='label.in_hospital_death',
        input_keep_prob=1.0,
        attribution_threshold=0.0001,
        attribution_max_delta_time=12 * 60 * 60,
        rnn_size=64,
        variational_recurrent_keep_prob=0.99,
        variational_input_keep_prob=0.97,
        variational_output_keep_prob=0.98,
        sequence_prediction=False,
        normalize=True,
        momentum=0.75,
        min_value=-1000.0,
        max_value=1000.0,
        # If sequence_prediction is True then the loss will also include the
        # sum of the changes in predictions across the sequence as a way to
        # learn models with less volatile predictions.
        volatility_loss_factor=0.0,
        include_sequence_prediction=True,
        include_gradients_attribution=True,
        include_gradients_sum_time_attribution=False,
        include_gradients_avg_time_attribution=False,
        include_path_integrated_gradients_attribution=True,
        use_rnn_attention=False,
        attention_hidden_layer_dim=0,
        include_diff_sequence_prediction_attribution=True,
        # If include_path_integrated_gradients_attribution determines the
        # number of steps between the old and the current observation value.
        path_integrated_gradients_num_steps=10,
    )

  def create_model_fn(self, hparams):
    """Returns a function to build the model.

    Args:
      hparams: The hyperparameters.

    Returns:
      A function to build the model's graph. This function is called by
      the Estimator object to construct the graph.
    """

    def model_fn(features, labels, mode):
      """Creates the prediction, loss, and train ops.

      Args:
        features: A dictionary of tensors keyed by the feature name.
        labels: A dictionary of label tensors keyed by the label key.
        mode: The execution mode, as defined in tf.contrib.learn.ModeKeys.

      Returns:
        EstimatorSpec with the mode, prediction, loss, train_op and
        output_alternatives a dictionary specifying the output for a
        servo request during serving.
      """
      # 1. Construct input to RNN.
      sequence_feature_map = {
          k: features[input_fn.SEQUENCE_KEY_PREFIX + k]
          for k in hparams.sequence_features
      }
      sequence_length = tf.squeeze(
          features[input_fn.CONTEXT_KEY_PREFIX + 'sequenceLength'],
          axis=1,
          name='sq_seq_len')
      tf.summary.scalar('sequence_length', tf.reduce_mean(sequence_length))
      diff_delta_time, obs_values, indicator = construct_input(
          sequence_feature_map, hparams.categorical_values,
          hparams.categorical_seq_feature, hparams.feature_value, mode,
          hparams.normalize, hparams.momentum, hparams.min_value,
          hparams.max_value, hparams.input_keep_prob)

      # seq_mask marks the real (unpadded) timesteps with 1.0.
      seq_mask = tf.expand_dims(
          tf.sequence_mask(sequence_length, dtype=tf.float32), axis=2)
      logits, weights = construct_logits(
          diff_delta_time,
          obs_values,
          indicator,
          sequence_length,
          seq_mask,
          hparams,
          reuse=False)

      all_attribution_dict = {}
      if mode == tf_estimator.ModeKeys.TRAIN:
        if hparams.sequence_prediction:
          assert not hparams.use_rnn_attention
          # If we train a sequence_prediction we repeat the labels over time.
          label_tensor = labels[hparams.label_key]
          labels[hparams.label_key] = tf.tile(
              tf.expand_dims(label_tensor, 2),
              multiples=[1, tf.shape(logits)[1], 1])
          if hparams.volatility_loss_factor > 0.0:
            # Penalize changes in consecutive predictions (masked to real
            # timesteps) to encourage less volatile sequence predictions.
            volatility = tf.reduce_sum(
                tf.square(seq_mask *
                          compute_prediction_diff_attribution(logits)))
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                 volatility * hparams.volatility_loss_factor)
        elif not hparams.use_rnn_attention:
          # Single-prediction training: keep only the last real timestep.
          logits = rnn_common.select_last_activations(
              logits, tf.to_int32(sequence_length))
      else:
        if hparams.sequence_prediction:
          last_logits = rnn_common.select_last_activations(
              logits, tf.to_int32(sequence_length))
        else:
          last_logits = logits
        if mode == tf_estimator.ModeKeys.PREDICT:
          delta_time = sequence_feature_map['deltaTime']
          all_attributions = {}
          if hparams.include_gradients_attribution:
            all_attributions['gradient_last'] = compute_gradient_attribution(
                last_logits, obs_values, indicator)
          if hparams.include_gradients_sum_time_attribution:
            assert not hparams.use_rnn_attention
            all_attributions['gradient_sum'] = compute_gradient_attribution(
                _predictions_for_gradients(
                    logits, seq_mask, delta_time,
                    hparams.attribution_max_delta_time, averaged=False),
                obs_values, indicator)
          if hparams.include_gradients_avg_time_attribution:
            assert not hparams.use_rnn_attention
            all_attributions['gradient_avg'] = compute_gradient_attribution(
                _predictions_for_gradients(
                    logits, seq_mask, delta_time,
                    hparams.attribution_max_delta_time, averaged=True),
                obs_values, indicator)
          if hparams.include_path_integrated_gradients_attribution:
            all_attributions['integrated_gradient'] = (
                compute_path_integrated_gradient_attribution(
                    obs_values, indicator, diff_delta_time, delta_time,
                    sequence_length, seq_mask, hparams))
          if hparams.use_rnn_attention:
            all_attributions['rnn_attention'] = weights
          if hparams.include_diff_sequence_prediction_attribution:
            all_attributions['diff_sequence'] = (
                compute_prediction_diff_attribution(logits))

          all_attribution_dict = {}
          for attribution_name, attribution in all_attributions.items():
            attribution_dict = convert_attribution(
                attribution,
                sequence_feature_map,
                seq_mask,
                delta_time,
                hparams.attribution_threshold,
                hparams.attribution_max_delta_time,
                prefix=attribution_name + '-')
            all_attribution_dict.update(attribution_dict)
          if hparams.include_sequence_prediction:
            # Add the predictions at each time step to the attention
            # dictionary.
            attribution_indices = tf.where(seq_mask > 0.5)
            all_attribution_dict['predictions'] = tf.sparse.expand_dims(
                tf.SparseTensor(
                    indices=attribution_indices,
                    values=tf.gather_nd(
                        tf.sigmoid(logits), attribution_indices),
                    dense_shape=tf.to_int64(tf.shape(delta_time))),
                axis=1)
        # At test/inference time we only make a single prediction even if we
        # did sequence_prediction during training.
        logits = last_logits
        seq_mask = None

      probabilities = tf.sigmoid(logits)
      classes = probabilities > 0.5
      predictions = {
          PredictionKeys.LOGITS: logits,
          PredictionKeys.PROBABILITIES: probabilities,
          PredictionKeys.CLASSES: classes
      }
      # Calculate the loss for TRAIN and EVAL, but not PREDICT.
      if mode == tf_estimator.ModeKeys.PREDICT:
        loss = None
      else:
        loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels[hparams.label_key],
            logits=predictions[PredictionKeys.LOGITS])
        if hparams.sequence_prediction:
          # Ignore the loss contribution of padded timesteps.
          loss *= seq_mask
        loss = tf.reduce_mean(loss)
        regularization_losses = tf.losses.get_regularization_losses()
        if regularization_losses:
          tf.summary.scalar('loss/prior_regularization', loss)
          regularization_loss = tf.add_n(regularization_losses)
          tf.summary.scalar('loss/regularization_loss', regularization_loss)
          loss += regularization_loss
        tf.summary.scalar('loss', loss)

      train_op = None
      if mode == tf_estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(
            learning_rate=hparams.learning_rate, beta1=0.9, beta2=0.999,
            epsilon=1e-8)
        optimizer = contrib_estimator.clip_gradients_by_norm(optimizer, 6.0)
        train_op = contrib_training.create_train_op(
            total_loss=loss, optimizer=optimizer, summarize_gradients=False)
      if mode != tf_estimator.ModeKeys.TRAIN:
        # Export attributions as flat (indices, values, shape) triples since
        # predictions must be dense Tensors, not SparseTensors.
        for k, v in all_attribution_dict.items():
          if not isinstance(v, tf.SparseTensor):
            raise ValueError('Expect attributions to be in SparseTensor, '
                             'getting %s for feature %s' %
                             (v.__class__.__name__, k))
          predictions['attention_attribution,%s,indices' % k] = v.indices
          predictions['attention_attribution,%s,values' % k] = v.values
          predictions['attention_attribution,%s,shape' % k] = v.dense_shape

      eval_metric_ops = {}
      if mode == tf_estimator.ModeKeys.EVAL:
        auc = tf.metrics.auc
        prob_k = PredictionKeys.PROBABILITIES
        class_k = PredictionKeys.CLASSES
        m = 'careful_interpolation'
        metric_fn_dict = {
            'auc-roc':
                lambda l, p: auc(l, p[prob_k], curve='ROC', summation_method=m),
            'auc-pr':
                lambda l, p: auc(l, p[prob_k], curve='PR', summation_method=m),
            'accuracy':
                lambda l, p: tf.metrics.accuracy(l, p[class_k]),
        }
        # Bug fix: the label used to be referenced via `label_tensor`, which
        # is only assigned inside the TRAIN + sequence_prediction branch and
        # therefore raised a NameError in EVAL mode. Read the label directly
        # from `labels`, which is never tiled outside of TRAIN.
        for (k, f) in metric_fn_dict.items():
          eval_metric_ops[k] = f(labels[hparams.label_key], predictions)
      # Define the output for serving.
      export_outputs = {}
      if mode == tf_estimator.ModeKeys.PREDICT:
        export_outputs = {
            'mortality': tf_estimator.export.PredictOutput(predictions)
        }

      return tf_estimator.EstimatorSpec(
          mode=mode,
          predictions=predictions,
          loss=loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs)

    return model_fn