# Source: google-research repository (viewer residue: 492 lines, 19.9 KB).
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16
17"""Tests for observation sequence model."""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import math
24
25from absl.testing import absltest
26from absl.testing import parameterized
27import numpy as np
28
29from six.moves import range
30import tensorflow.compat.v1 as tf
31from tensorflow.compat.v1 import estimator as tf_estimator
32from explaining_risk_increase import input_fn
33from explaining_risk_increase import observation_sequence_model as osm
34from tensorflow.contrib import training as contrib_training
35
36
class ObservationSequenceTest(tf.test.TestCase, parameterized.TestCase):
  """Unit tests for the observation sequence model helpers."""

  def setUp(self):
    """Creates a shared [batch=2, time=3, 1] sparse observation-value tensor."""
    super(ObservationSequenceTest, self).setUp()
    self.observation_values = tf.SparseTensor(
        indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 2, 0]],
        values=[100.0, 2.3, 0.5, 0.0, 4.0],
        dense_shape=[2, 3, 1])
46
47def testGradientAttribution(self):
48factors = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
49observation_values = tf.constant(
50[[[0, 100.0, 0, 0, 0], [0, 2.3, 0, 0, 0], [0, 0, 0, 0, 0]],
51[[0, 0, 0, 0.5, 0], [0.0, 0, 0, 0, 0], [0, 4.0, 0, 0, 0]]])
52indicator = tf.constant(
53[[[0, 1.0, 0, 0, 0], [0, 1, 0, 0, 0], [1, 0, 0, 0, 0]],
54[[0, 0, 0, 1, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]])
55last_logits = tf.reduce_sum(
56tf.reduce_sum(
57observation_values * tf.expand_dims(factors, axis=2),
58axis=2,
59keepdims=True),
60axis=1,
61keepdims=True)
62gradients = osm.compute_gradient_attribution(
63last_logits, obs_values=observation_values, indicator=indicator)
64with self.test_session() as sess:
65acutal_gradients = sess.run(tf.squeeze(gradients))
66self.assertAllClose(factors, acutal_gradients, atol=0.01)
67
68def testAttention(self):
69seq_output = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
70[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]]
71seq_mask = [[[1.0], [0.0], [0.0]], [[1.0], [1.0], [1.0]]]
72last_output = [[1.0, 2.0], [.5, .6]]
73
74results = osm.compute_attention(
75seq_output, last_output, hidden_layer_dim=0, seq_mask=seq_mask,
76sequence_length=[1, 3])
77expected_alpha = np.array(
78[[1.0 * 1 + 2 * 2, 0, 0],
79[0.1 * .5 + .2 * .6, .3 * .5 + .4 * .6, .5 * .5 + .6 * .6]])
80exp = np.exp(expected_alpha - np.max(expected_alpha, axis=1, keepdims=True))
81expected_beta = exp / np.sum(exp, axis=1, keepdims=True)
82expected_beta = np.expand_dims(expected_beta, 2)
83expected_attn = np.sum(np.array(seq_output) * expected_beta, axis=1)
84
85with self.test_session() as sess:
86actual_attention, acutal_beta = sess.run(results)
87self.assertAllClose(expected_beta, acutal_beta, atol=0.01)
88self.assertAllClose(expected_attn, actual_attention, atol=0.01)
89
  def testIntegratedGradientAttribution(self):
    """Path-integrated gradients should equal the logit diff vs. the baseline.

    For the locally linear model below, the integrated gradients summed over
    the path must match prediction(obs_values) - prediction(obs_values_base).
    """
    # Due to complexity of the indicator we cannot easily extend this test to
    # > 1 lab test.
    obs_values = tf.constant([[[10000.0], [15000.0], [2.0]],
                              [[0.0], [100.0], [2000.0]]])

    # We compare these values to a linear interpolation between the second to
    # the last and the last value of the test.
    obs_values_base = tf.constant(
        [[[10000.0], [15000.0], [15000.0]], [[0.0], [100.0], [100.0]]])
    # For this test we need to select all attributions in order for consistency
    # to hold.
    indicator = tf.ones(shape=[2, 3, 1], dtype=tf.float32)
    delta_time = tf.constant(
        [[[1000], [999], [2]], [[1001], [500], [20]]], dtype=tf.float32)
    # Selected so that the attribution is only over the third time step in both
    # batch entries.
    attribution_max_delta_time = 100
    num_classes = 1

    diff_delta_time = tf.constant(
        [[[1000], [1], [997]], [[1001], [501], [480]]], dtype=tf.float32)
    # This is also important to not lose any time steps in the attribution.
    sequence_length = tf.constant([3, 3])

    # TODO(milah): Not clear why this test doesn't work for the RNN.
    def construct_logits_fn(
        unused_diff_delta_time, obs_values, unused_indicator,
        unused_sequence_length, unused_seq_mask, unused_hparams,
        reuse):
      # The multiplier uses only the first time step, which is identical in
      # obs_values and obs_values_base, so the logits are linear in the values
      # that actually vary along the interpolation path.
      result = tf.layers.dense(
          obs_values, num_classes, name='test1', reuse=reuse,
          activation=None) * (
              tf.expand_dims(obs_values[:, 0, :], axis=1) + 0.5)
      return result, None

    # First setup the weights of the RNN.
    logits, _ = construct_logits_fn(diff_delta_time, obs_values, indicator,
                                    sequence_length, None, None, False)
    # To verify the correctness of the attribution we compute the prediction at
    # the obs_values_base.
    base_logits, _ = construct_logits_fn(diff_delta_time, obs_values_base,
                                         indicator, sequence_length, None, None,
                                         True)

    # Set high for increased precision of the approximation.
    num_steps = 100
    hparams = contrib_training.HParams(
        sequence_prediction=True,
        use_rnn_attention=False,
        path_integrated_gradients_num_steps=num_steps,
        attribution_max_delta_time=attribution_max_delta_time)
    gradients = osm.compute_path_integrated_gradient_attribution(
        obs_values, indicator, diff_delta_time, delta_time, sequence_length,
        None, hparams, construct_logits_fn)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      actual_logits = sess.run(logits)
      actual_base_logits = sess.run(base_logits)
      actual_gradients = sess.run(gradients)
      self.assertAllClose(
          actual_logits - actual_base_logits, actual_gradients, atol=0.001)
152
153def testLastObservations(self):
154obs_values = tf.constant(
155[[[0, 100.0, 0, 0, 0], [0, 2.3, 0, 0, 0], [-1.0, 0, 0, 0, 0]],
156[[0, 0, 0, 0.5, 0], [0.0, 0, 0, 0, 0], [0, 4.0, 0, 0, 0]]])
157indicator = tf.constant(
158[[[0, 1.0, 0, 0, 0], [0, 1, 0, 0, 0], [1, 0, 0, 0, 0]],
159[[0, 0, 0, 1, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]])
160
161delta_time = tf.constant([
162[[1000], [1001], [2]], # the last event is too new.
163[[10], [20], [22]]
164])
165attribution_max_delta_time = 10
166
167expected_result = [[[0, 2.3, 0, 0, 0]], [[0, 4.0, 0, 0.5, 0]]]
168
169last_vals = osm._most_recent_obs_value(
170obs_values, indicator, delta_time, attribution_max_delta_time)
171with self.test_session() as sess:
172actual_last_vals = sess.run(last_vals)
173self.assertAllClose(expected_result, actual_last_vals, atol=0.01)
174
175def testGradientPredictions(self):
176logits = [[[0.0], [100.0], [200.0]], [[1000.0], [100.0], [5.0]]]
177
178delta_time = tf.constant([
179[[1000000], [1000001], [20]], # first two events are too old.
180[[10], [20], [20]]
181])
182
183seq_mask = tf.constant([
184[[1.0], [1.0], [1.0]],
185[[1.0], [1.0], [0.0]] # last event is padded
186])
187
188predictions = osm._predictions_for_gradients(
189logits, seq_mask, delta_time, attribution_max_delta_time=100,
190averaged=False)
191
192avg_predictions = osm._predictions_for_gradients(
193logits, seq_mask, delta_time, attribution_max_delta_time=100,
194averaged=True)
195expected_predictions = [[[200.0]], [[1100.0]]]
196avg_expected_predictions = [[[200.0]], [[550.0]]]
197
198with self.test_session() as sess:
199actual_pred, = sess.run([predictions])
200self.assertAllClose(expected_predictions, actual_pred)
201
202avg_actual_pred, = sess.run([avg_predictions])
203self.assertAllClose(avg_expected_predictions, avg_actual_pred)
204
  def testAttribution(self):
    """Low-level test for the correctness of compute_attribution."""
    logits = [[[0.0], [100.0], [200.0]], [[-1000.0], [100.0], [5.0]]]

    delta_time = tf.constant([
        [[1000000], [1000001], [20]],  # first two events are too old.
        [[10], [9], [8]]
    ])

    sequence_feature_map = {
        'obs_vals': self.observation_values,
        'deltaTime': delta_time
    }

    seq_mask = tf.constant([
        [[1.0], [1.0], [0.0]],  # last event is padded
        [[1.0], [1.0], [1.0]]
    ])

    attribution_threshold = 0.01

    # The single index expected to survive thresholding in the returned sparse
    # attribution — NOTE(review): presumably laid out as
    # [batch, ?, time, feature]; confirm against convert_attribution.
    expected_ixs = [1, 0, 1, 0]

    def _sigmoid(x):
      return math.exp(x) / (1 + math.exp(x))

    # Attribution value = difference between consecutive sigmoid predictions
    # at the expected position.
    expected_values = [
        _sigmoid(logits[expected_ixs[0]][expected_ixs[2]][expected_ixs[3]]) -
        _sigmoid(logits[expected_ixs[0]][expected_ixs[2] - 1][expected_ixs[3]])
    ]

    attribution = osm.compute_prediction_diff_attribution(logits)
    attribution_dict = osm.convert_attribution(
        attribution, sequence_feature_map, seq_mask, delta_time,
        attribution_threshold, 12 * 60 * 60)
    # Every sequence feature gets an attribution entry.
    self.assertEqual(
        set(sequence_feature_map.keys()), set(attribution_dict.keys()))
    with self.test_session() as sess:
      actual_attr_val, actual_attr_time = sess.run(
          [attribution_dict['obs_vals'], attribution_dict['deltaTime']])
      self.assertAllClose([expected_ixs], actual_attr_val.indices)
      self.assertAllClose(expected_values, actual_attr_val.values)
      self.assertAllClose([expected_ixs], actual_attr_time.indices)
      self.assertAllClose(expected_values, actual_attr_time.values)
249
250def testCombine(self):
251"""Low-level test for the results of combine_observation_code_and_values."""
252
253observation_code_ids = tf.SparseTensor(
254indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 2, 0]],
255values=tf.constant([1, 1, 3, 0, 1], dtype=tf.int64),
256dense_shape=[2, 3, 1])
257
258vocab_size = 5
259
260expected_result = [[[0, 100.0, 0, 0, 0], [0, 2.3, 0, 0, 0],
261[0, 0, 0, 0, 0]], [[0, 0, 0, 0.5, 0], [0.0, 0, 0, 0, 0],
262[0, 4.0, 0, 0, 0]]]
263expected_indicator = [[[0, 1.0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0]],
264[[0, 0, 0, 1, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]]
265
266acutal_result, acutal_indicator = osm.combine_observation_code_and_values(
267observation_code_ids=observation_code_ids,
268observation_values=self.observation_values,
269vocab_size=vocab_size,
270mode=tf_estimator.ModeKeys.TRAIN,
271normalize=False,
272momentum=0.9,
273min_value=-10000000,
274max_value=10000000)
275with self.test_session() as sess:
276a_result, a_indicator = sess.run([acutal_result, acutal_indicator])
277self.assertAllClose(expected_result, a_result, atol=0.01)
278self.assertAllClose(expected_indicator, a_indicator, atol=0.01)
279
280def testRnnInput(self):
281observation_values = tf.SparseTensor(
282indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0],
283[1, 0, 0], [1, 1, 0], [1, 2, 0]],
284values=[100.0, 2.3, 9999999.0, 0.5, 0.0, 4.0],
285dense_shape=[2, 3, 1])
286observation_code_ids = tf.SparseTensor(
287indices=observation_values.indices,
288values=['loinc:2', 'loinc:1', 'loinc:2',
289'loinc:1', 'MISSING', 'loinc:1'],
290dense_shape=observation_values.dense_shape)
291delta_time, obs_values, indicator = osm.construct_input(
292{
293'Observation.code':
294observation_code_ids,
295'Observation.valueQuantity.value':
296observation_values,
297'deltaTime':
298tf.constant([[[2 * 60 * 60], [3 * 60 * 60], [0]],
299[[1 * 60 * 60], [3 * 60 * 60], [6 * 60 * 60]]],
300dtype=tf.int64)
301}, ['loinc:1', 'loinc:2', 'MISSING'],
302'Observation.code',
303'Observation.valueQuantity.value',
304mode=tf_estimator.ModeKeys.TRAIN,
305normalize=False,
306momentum=0.9,
307min_value=-10000000,
308max_value=10000000,
309input_keep_prob=1.0)
310
311result = tf.concat([delta_time, indicator, obs_values], axis=2)
312
313expected_result = [
314[[0, 0, 1, 0, 0, 100, 0],
315[-1, 1, 0, 0, 2.3, 0, 0],
316# value 9999999.0 was filtered.
317[3, 0, 0, 0, 0, 0, 0]
318],
319[[0, 1, 0, 0, 0.5, 0, 0],
320[-2, 0, 0, 1, 0, 0, 0],
321[-3, 1, 0, 0, 4.0, 0, 0]]
322]
323
324with self.test_session() as sess:
325sess.run(tf.tables_initializer())
326actual_result = sess.run(result)
327print(actual_result)
328self.assertAllClose(expected_result, actual_result, atol=0.01)
329
330def testEmptyRnnInput(self):
331observation_values = tf.SparseTensor(
332indices=tf.reshape(tf.constant([], dtype=tf.int64), shape=[0, 3]),
333values=tf.constant([], dtype=tf.float32),
334dense_shape=[2, 0, 1])
335observation_code_ids = tf.SparseTensor(
336indices=observation_values.indices,
337values=tf.constant([], dtype=tf.string),
338dense_shape=observation_values.dense_shape)
339delta_time, obs_values, indicator = osm.construct_input(
340{
341'Observation.code':
342observation_code_ids,
343'Observation.valueQuantity.value':
344observation_values,
345'deltaTime':
346tf.reshape(tf.constant([[], []], dtype=tf.int64), [2, 0, 1])
347}, ['loinc:1', 'loinc:2', 'MISSING'],
348'Observation.code',
349'Observation.valueQuantity.value',
350mode=tf_estimator.ModeKeys.TRAIN,
351normalize=False,
352momentum=0.9,
353min_value=-10000000,
354max_value=10000000,
355input_keep_prob=1.0)
356
357result = tf.concat([delta_time, indicator, obs_values], axis=2)
358
359with self.test_session() as sess:
360sess.run(tf.tables_initializer())
361actual_result, = sess.run([tf.shape(result)])
362self.assertAllClose([2, 0, 7], actual_result)
363
  @parameterized.parameters(
      (True, True, False, False, False, False, False, 0, 0.0),
      (False, False, True, False, False, False, False, 0, 0.0),
      (True, False, False, True, False, False, False, 0, 0.0),
      (False, False, False, False, True, False, False, 0, 0.0),
      (False, False, False, False, False, True, False, 0, 0.0),
      (True, True, True, True, True, True, False, 0, 0.0),
      (True, True, True, True, True, True, False, 0, 1.0),
      (False, False, False, False, False, False, True, 0, 0.0),
      (False, True, False, False, True, False, True, 5, 0.0),
  )
  def testBasicModelFn(self, sequence_prediction, include_gradients,
                       include_gradients_sum_time, include_gradients_avg_time,
                       include_path_integrated_gradients,
                       include_diff_sequence_prediction, use_rnn_attention,
                       attention_hidden_layer_dim, volatility_loss_factor):
    """This high-level test ensures there are no errors during training.

    It also checks that the loss is decreasing.

    Args:
      sequence_prediction: Whether to consider the recent predictions in the
        loss or only the most last prediction.
      include_gradients: Whether to generate attribution with the
        gradients of the last predictions.
      include_gradients_sum_time: Whether to generate attribution
        with the gradients of the sum of the predictions over time.
      include_gradients_avg_time: Whether to generate attribution
        with the gradients of the average of the predictions over time.
      include_path_integrated_gradients: Whether to generate
        attribution with the integrated gradients of last predictions compared
        to their most recent values before attribution_max_delta_time.
      include_diff_sequence_prediction: Whether to
        generate attribution from the difference of consecutive predictions.
      use_rnn_attention: Whether to use attention for the RNN.
      attention_hidden_layer_dim: If use_rnn_attention what the dimensionality
        of a hidden layer should be (or 0 if none) of last output and
        intermediates before multiplying to obtain a weight.
      volatility_loss_factor: Include the sum of the changes in predictions
        across the sequence in the loss multiplied by this factor.
    """
    num_steps = 2
    hparams = contrib_training.HParams(
        batch_size=2,
        learning_rate=0.008,
        sequence_features=[
            'deltaTime', 'Observation.code', 'Observation.valueQuantity.value'
        ],
        categorical_values=['loinc:1', 'loinc:2', 'MISSING'],
        categorical_seq_feature='Observation.code',
        context_features=['sequenceLength'],
        feature_value='Observation.valueQuantity.value',
        label_key='label.in_hospital_death',
        attribution_threshold=-1.0,
        rnn_size=6,
        # NOTE(review): keep probabilities > 1.0 look odd — presumably treated
        # as "dropout disabled" downstream; confirm in the model code.
        variational_recurrent_keep_prob=1.1,
        variational_input_keep_prob=1.1,
        variational_output_keep_prob=1.1,
        sequence_prediction=sequence_prediction,
        time_decayed=False,
        normalize=True,
        momentum=0.9,
        min_value=-1000.0,
        max_value=1000.0,
        volatility_loss_factor=volatility_loss_factor,
        attribution_max_delta_time=100000,
        input_keep_prob=1.0,
        include_sequence_prediction=sequence_prediction,
        include_gradients_attribution=include_gradients,
        include_gradients_sum_time_attribution=include_gradients_sum_time,
        include_gradients_avg_time_attribution=include_gradients_avg_time,
        include_path_integrated_gradients_attribution=(
            include_path_integrated_gradients),
        include_diff_sequence_prediction_attribution=(
            include_diff_sequence_prediction),
        use_rnn_attention=use_rnn_attention,
        attention_hidden_layer_dim=attention_hidden_layer_dim,
        path_integrated_gradients_num_steps=10,
    )
    observation_values = tf.SparseTensor(
        indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0],
                 [1, 0, 0], [1, 1, 0], [1, 2, 0]],
        values=[100.0, 2.3, 9999999.0, 0.5, 0.0, 4.0],
        dense_shape=[2, 3, 1])
    model = osm.ObservationSequenceModel()
    model_fn = model.create_model_fn(hparams)
    features = {
        input_fn.CONTEXT_KEY_PREFIX + 'sequenceLength':
            tf.constant([[2], [3]], dtype=tf.int64),
        input_fn.SEQUENCE_KEY_PREFIX + 'Observation.code':
            tf.SparseTensor(
                indices=observation_values.indices,
                values=[
                    'loinc:2', 'loinc:1', 'loinc:2', 'loinc:1', 'MISSING',
                    'loinc:1'
                ],
                dense_shape=observation_values.dense_shape),
        input_fn.SEQUENCE_KEY_PREFIX + 'Observation.valueQuantity.value':
            observation_values,
        input_fn.SEQUENCE_KEY_PREFIX + 'deltaTime':
            tf.constant([[[1], [2], [0]], [[1], [3], [4]]], dtype=tf.int64)
    }
    label_key = 'label.in_hospital_death'
    labels = {label_key: tf.constant([[1.0], [0.0]], dtype=tf.float32)}
    # Build the TRAIN graph, then a PREDICT graph sharing the same variables.
    with tf.variable_scope('test'):
      model_fn_ops_train = model_fn(features, labels,
                                    tf_estimator.ModeKeys.TRAIN)
    with tf.variable_scope('test', reuse=True):
      # PREDICT mode gets the label as a context feature instead of `labels`.
      features[input_fn.CONTEXT_KEY_PREFIX + 'label.in_hospital_death'
              ] = tf.SparseTensor(indices=[[0, 0]], values=['expired'],
                                  dense_shape=[2, 1])
      model_fn_ops_eval = model_fn(
          features, labels=None, mode=tf_estimator.ModeKeys.PREDICT)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.tables_initializer())
      # Test train.
      for i in range(num_steps):
        loss, _ = sess.run(
            [model_fn_ops_train.loss, model_fn_ops_train.train_op])
        if i == 0:
          initial_loss = loss
      # After num_steps updates the loss must have decreased.
      self.assertLess(loss, initial_loss)
      # Test infer.
      sess.run(model_fn_ops_eval.predictions)
490
# Run all tests via absl when executed as a script.
if __name__ == '__main__':
  absltest.main()
493