# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Tests for observation sequence model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from six.moves import range
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from explaining_risk_increase import input_fn
from explaining_risk_increase import observation_sequence_model as osm
from tensorflow.contrib import training as contrib_training


class ObservationSequenceTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(ObservationSequenceTest, self).setUp()
    self.observation_values = tf.SparseTensor(
        indices=[[0, 0, 0], [0, 1, 0],
                 [1, 0, 0], [1, 1, 0], [1, 2, 0]],
        values=[100.0, 2.3, 0.5, 0.0, 4.0],
        dense_shape=[2, 3, 1])

  def testGradientAttribution(self):
    factors = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    observation_values = tf.constant(
        [[[0, 100.0, 0, 0, 0], [0, 2.3, 0, 0, 0], [0, 0, 0, 0, 0]],
         [[0, 0, 0, 0.5, 0], [0.0, 0, 0, 0, 0], [0, 4.0, 0, 0, 0]]])
    indicator = tf.constant(
        [[[0, 1.0, 0, 0, 0], [0, 1, 0, 0, 0], [1, 0, 0, 0, 0]],
         [[0, 0, 0, 1, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]])
    last_logits = tf.reduce_sum(
        tf.reduce_sum(
            observation_values * tf.expand_dims(factors, axis=2),
            axis=2,
            keepdims=True),
        axis=1,
        keepdims=True)
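    # last_logits is linear in observation_values: time step t of batch entry
    # b contributes factors[b][t] * sum(values[b][t]), so the gradient with
    # respect to each observed value should recover factors[b][t] exactly.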
    gradients = osm.compute_gradient_attribution(
        last_logits, obs_values=observation_values, indicator=indicator)
    with self.test_session() as sess:
      actual_gradients = sess.run(tf.squeeze(gradients))
      self.assertAllClose(factors, actual_gradients, atol=0.01)

  def testAttention(self):
    seq_output = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                  [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]]
    seq_mask = [[[1.0], [0.0], [0.0]], [[1.0], [1.0], [1.0]]]
    last_output = [[1.0, 2.0], [.5, .6]]

    results = osm.compute_attention(
        seq_output, last_output, hidden_layer_dim=0, seq_mask=seq_mask,
        sequence_length=[1, 3])
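    # With hidden_layer_dim=0 the attention score of each step is the dot
    # product of that step's output with last_output, zeroed at masked
    # positions; beta is the softmax of the scores over time, and the
    # attention output is the beta-weighted sum of seq_output.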
    expected_alpha = np.array(
        [[1.0 * 1 + 2 * 2, 0, 0],
         [0.1 * .5 + .2 * .6, .3 * .5 + .4 * .6, .5 * .5 + .6 * .6]])
    exp = np.exp(expected_alpha - np.max(expected_alpha, axis=1, keepdims=True))
    expected_beta = exp / np.sum(exp, axis=1, keepdims=True)
    expected_beta = np.expand_dims(expected_beta, 2)
    expected_attn = np.sum(np.array(seq_output) * expected_beta, axis=1)

    with self.test_session() as sess:
      actual_attention, actual_beta = sess.run(results)
      self.assertAllClose(expected_beta, actual_beta, atol=0.01)
      self.assertAllClose(expected_attn, actual_attention, atol=0.01)

  def testIntegratedGradientAttribution(self):
    # Due to the complexity of the indicator we cannot easily extend this test
    # to more than one lab test.
    obs_values = tf.constant([[[10000.0], [15000.0], [2.0]],
                              [[0.0], [100.0], [2000.0]]])

    # We compare these values to a linear interpolation between the
    # second-to-last and the last value of the test.
    obs_values_base = tf.constant(
        [[[10000.0], [15000.0], [15000.0]], [[0.0], [100.0], [100.0]]])
    # For this test we need to select all attributions in order for consistency
    # to hold.
    indicator = tf.ones(shape=[2, 3, 1], dtype=tf.float32)
    delta_time = tf.constant(
        [[[1000], [999], [2]], [[1001], [500], [20]]], dtype=tf.float32)
    # Selected so that the attribution is only over the third time step in both
    # batch entries.
    attribution_max_delta_time = 100
    num_classes = 1

    diff_delta_time = tf.constant(
        [[[1000], [1], [997]], [[1001], [501], [480]]], dtype=tf.float32)
    # This is also important so as not to lose any time steps in the
    # attribution.
    sequence_length = tf.constant([3, 3])

    # TODO(milah): Not clear why this test doesn't work for the RNN.
    def construct_logits_fn(
        unused_diff_delta_time, obs_values, unused_indicator,
        unused_sequence_length, unused_seq_mask, unused_hparams,
        reuse):
      result = tf.layers.dense(
          obs_values, num_classes, name='test1', reuse=reuse,
          activation=None) * (
              tf.expand_dims(obs_values[:, 0, :], axis=1) + 0.5)
      return result, None

    # First set up the weights of the model.
    logits, _ = construct_logits_fn(diff_delta_time, obs_values, indicator,
                                    sequence_length, None, None, False)
    # To verify the correctness of the attribution we compute the prediction at
    # the obs_values_base.
    base_logits, _ = construct_logits_fn(diff_delta_time, obs_values_base,
                                         indicator, sequence_length, None, None,
                                         True)

    # Set high for increased precision of the approximation.
    num_steps = 100
    hparams = contrib_training.HParams(
        sequence_prediction=True,
        use_rnn_attention=False,
        path_integrated_gradients_num_steps=num_steps,
        attribution_max_delta_time=attribution_max_delta_time)
    gradients = osm.compute_path_integrated_gradient_attribution(
        obs_values, indicator, diff_delta_time, delta_time, sequence_length,
        None, hparams, construct_logits_fn)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      actual_logits = sess.run(logits)
      actual_base_logits = sess.run(base_logits)
      actual_gradients = sess.run(gradients)
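      # Completeness property of path-integrated gradients: the attributions
      # should sum to the difference between the prediction at obs_values and
      # the prediction at the obs_values_base baseline.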
      self.assertAllClose(
          actual_logits - actual_base_logits, actual_gradients, atol=0.001)

  def testLastObservations(self):
    obs_values = tf.constant(
        [[[0, 100.0, 0, 0, 0], [0, 2.3, 0, 0, 0], [-1.0, 0, 0, 0, 0]],
         [[0, 0, 0, 0.5, 0], [0.0, 0, 0, 0, 0], [0, 4.0, 0, 0, 0]]])
    indicator = tf.constant(
        [[[0, 1.0, 0, 0, 0], [0, 1, 0, 0, 0], [1, 0, 0, 0, 0]],
         [[0, 0, 0, 1, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]])

    delta_time = tf.constant([
        [[1000], [1001], [2]],  # the last event is too new.
        [[10], [20], [22]]
    ])
    attribution_max_delta_time = 10
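
    # For each feature, the expected value comes from the last sequence
    # position whose delta_time is at least attribution_max_delta_time; newer
    # events (such as the -1.0 at delta_time=2) are skipped.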
    expected_result = [[[0, 2.3, 0, 0, 0]], [[0, 4.0, 0, 0.5, 0]]]

    last_vals = osm._most_recent_obs_value(
        obs_values, indicator, delta_time, attribution_max_delta_time)
    with self.test_session() as sess:
      actual_last_vals = sess.run(last_vals)
      self.assertAllClose(expected_result, actual_last_vals, atol=0.01)

  def testGradientPredictions(self):
    logits = [[[0.0], [100.0], [200.0]], [[1000.0], [100.0], [5.0]]]

    delta_time = tf.constant([
        [[1000000], [1000001], [20]],  # first two events are too old.
        [[10], [20], [20]]
    ])

    seq_mask = tf.constant([
        [[1.0], [1.0], [1.0]],
        [[1.0], [1.0], [0.0]]  # last event is padded
    ])

    predictions = osm._predictions_for_gradients(
        logits, seq_mask, delta_time, attribution_max_delta_time=100,
        averaged=False)

    avg_predictions = osm._predictions_for_gradients(
        logits, seq_mask, delta_time, attribution_max_delta_time=100,
        averaged=True)
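    # Only unpadded steps within attribution_max_delta_time contribute: batch
    # 0 keeps just the third step (200.0), while batch 1 keeps the first two,
    # giving 1100.0 summed and 550.0 averaged.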
    expected_predictions = [[[200.0]], [[1100.0]]]
    avg_expected_predictions = [[[200.0]], [[550.0]]]

    with self.test_session() as sess:
      actual_pred, = sess.run([predictions])
      self.assertAllClose(expected_predictions, actual_pred)

      avg_actual_pred, = sess.run([avg_predictions])
      self.assertAllClose(avg_expected_predictions, avg_actual_pred)

  def testAttribution(self):
    """Low-level test for compute_prediction_diff_attribution."""
    logits = [[[0.0], [100.0], [200.0]], [[-1000.0], [100.0], [5.0]]]

    delta_time = tf.constant([
        [[1000000], [1000001], [20]],  # first two events are too old.
        [[10], [9], [8]]
    ])

    sequence_feature_map = {
        'obs_vals': self.observation_values,
        'deltaTime': delta_time
    }

    seq_mask = tf.constant([
        [[1.0], [1.0], [0.0]],  # last event is padded
        [[1.0], [1.0], [1.0]]
    ])

    attribution_threshold = 0.01

    expected_ixs = [1, 0, 1, 0]

    def _sigmoid(x):
      return math.exp(x) / (1 + math.exp(x))
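
    # compute_prediction_diff_attribution scores each time step by the change
    # in sigmoid(logits) from the previous step; convert_attribution then
    # keeps only entries that pass seq_mask, attribution_threshold, and the
    # 12 * 60 * 60 recency cutoff, leaving the single index in expected_ixs.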
    expected_values = [
        _sigmoid(logits[expected_ixs[0]][expected_ixs[2]][expected_ixs[3]]) -
        _sigmoid(logits[expected_ixs[0]][expected_ixs[2] - 1][expected_ixs[3]])
    ]

    attribution = osm.compute_prediction_diff_attribution(logits)
    attribution_dict = osm.convert_attribution(
        attribution, sequence_feature_map, seq_mask, delta_time,
        attribution_threshold, 12 * 60 * 60)
    self.assertEqual(
        set(sequence_feature_map.keys()), set(attribution_dict.keys()))
    with self.test_session() as sess:
      actual_attr_val, actual_attr_time = sess.run(
          [attribution_dict['obs_vals'], attribution_dict['deltaTime']])
      self.assertAllClose([expected_ixs], actual_attr_val.indices)
      self.assertAllClose(expected_values, actual_attr_val.values)
      self.assertAllClose([expected_ixs], actual_attr_time.indices)
      self.assertAllClose(expected_values, actual_attr_time.values)

  def testCombine(self):
    """Low-level test for the results of combine_observation_code_and_values."""

    observation_code_ids = tf.SparseTensor(
        indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 2, 0]],
        values=tf.constant([1, 1, 3, 0, 1], dtype=tf.int64),
        dense_shape=[2, 3, 1])

    vocab_size = 5
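
    # combine_observation_code_and_values scatters each sparse value into a
    # dense [batch, time, vocab_size] tensor at the column given by its code
    # id, while the returned indicator marks which cells hold an observation.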
    expected_result = [[[0, 100.0, 0, 0, 0], [0, 2.3, 0, 0, 0],
                        [0, 0, 0, 0, 0]], [[0, 0, 0, 0.5, 0], [0.0, 0, 0, 0, 0],
                                           [0, 4.0, 0, 0, 0]]]
    expected_indicator = [[[0, 1.0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0]],
                          [[0, 0, 0, 1, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]]

    actual_result, actual_indicator = osm.combine_observation_code_and_values(
        observation_code_ids=observation_code_ids,
        observation_values=self.observation_values,
        vocab_size=vocab_size,
        mode=tf_estimator.ModeKeys.TRAIN,
        normalize=False,
        momentum=0.9,
        min_value=-10000000,
        max_value=10000000)
    with self.test_session() as sess:
      a_result, a_indicator = sess.run([actual_result, actual_indicator])
      self.assertAllClose(expected_result, a_result, atol=0.01)
      self.assertAllClose(expected_indicator, a_indicator, atol=0.01)

  def testRnnInput(self):
    observation_values = tf.SparseTensor(
        indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0],
                 [1, 0, 0], [1, 1, 0], [1, 2, 0]],
        values=[100.0, 2.3, 9999999.0, 0.5, 0.0, 4.0],
        dense_shape=[2, 3, 1])
    observation_code_ids = tf.SparseTensor(
        indices=observation_values.indices,
        values=['loinc:2', 'loinc:1', 'loinc:2',
                'loinc:1', 'MISSING', 'loinc:1'],
        dense_shape=observation_values.dense_shape)
    delta_time, obs_values, indicator = osm.construct_input(
        {
            'Observation.code':
                observation_code_ids,
            'Observation.valueQuantity.value':
                observation_values,
            'deltaTime':
                tf.constant([[[2 * 60 * 60], [3 * 60 * 60], [0]],
                             [[1 * 60 * 60], [3 * 60 * 60], [6 * 60 * 60]]],
                            dtype=tf.int64)
        }, ['loinc:1', 'loinc:2', 'MISSING'],
        'Observation.code',
        'Observation.valueQuantity.value',
        mode=tf_estimator.ModeKeys.TRAIN,
        normalize=False,
        momentum=0.9,
        min_value=-10000000,
        max_value=10000000,
        input_keep_prob=1.0)

    result = tf.concat([delta_time, indicator, obs_values], axis=2)
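
    # Per step, column 0 is the previous step's deltaTime minus this step's,
    # in hours (0 for the first step); columns 1-3 one-hot encode the observed
    # code over the vocabulary ['loinc:1', 'loinc:2', 'MISSING']; columns 4-6
    # hold the corresponding observed values.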
    expected_result = [
        [[0, 0, 1, 0, 0, 100, 0],
         [-1, 1, 0, 0, 2.3, 0, 0],
         # value 9999999.0 was filtered.
         [3, 0, 0, 0, 0, 0, 0]
        ],
        [[0, 1, 0, 0, 0.5, 0, 0],
         [-2, 0, 0, 1, 0, 0, 0],
         [-3, 1, 0, 0, 4.0, 0, 0]]
    ]

    with self.test_session() as sess:
      sess.run(tf.tables_initializer())
      actual_result = sess.run(result)
      print(actual_result)
      self.assertAllClose(expected_result, actual_result, atol=0.01)

  def testEmptyRnnInput(self):
    observation_values = tf.SparseTensor(
        indices=tf.reshape(tf.constant([], dtype=tf.int64), shape=[0, 3]),
        values=tf.constant([], dtype=tf.float32),
        dense_shape=[2, 0, 1])
    observation_code_ids = tf.SparseTensor(
        indices=observation_values.indices,
        values=tf.constant([], dtype=tf.string),
        dense_shape=observation_values.dense_shape)
    delta_time, obs_values, indicator = osm.construct_input(
        {
            'Observation.code':
                observation_code_ids,
            'Observation.valueQuantity.value':
                observation_values,
            'deltaTime':
                tf.reshape(tf.constant([[], []], dtype=tf.int64), [2, 0, 1])
        }, ['loinc:1', 'loinc:2', 'MISSING'],
        'Observation.code',
        'Observation.valueQuantity.value',
        mode=tf_estimator.ModeKeys.TRAIN,
        normalize=False,
        momentum=0.9,
        min_value=-10000000,
        max_value=10000000,
        input_keep_prob=1.0)

    result = tf.concat([delta_time, indicator, obs_values], axis=2)

    with self.test_session() as sess:
      sess.run(tf.tables_initializer())
      actual_result, = sess.run([tf.shape(result)])
      self.assertAllClose([2, 0, 7], actual_result)

  @parameterized.parameters(
      (True, True, False, False, False, False, False, 0, 0.0),
      (False, False, True, False, False, False, False, 0, 0.0),
      (True, False, False, True, False, False, False, 0, 0.0),
      (False, False, False, False, True, False, False, 0, 0.0),
      (False, False, False, False, False, True, False, 0, 0.0),
      (True, True, True, True, True, True, False, 0, 0.0),
      (True, True, True, True, True, True, False, 0, 1.0),
      (False, False, False, False, False, False, True, 0, 0.0),
      (False, True, False, False, True, False, True, 5, 0.0),
  )
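  # Each tuple maps positionally to the arguments of testBasicModelFn below:
  # (sequence_prediction, include_gradients, include_gradients_sum_time,
  #  include_gradients_avg_time, include_path_integrated_gradients,
  #  include_diff_sequence_prediction, use_rnn_attention,
  #  attention_hidden_layer_dim, volatility_loss_factor).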
  def testBasicModelFn(self, sequence_prediction, include_gradients,
                       include_gradients_sum_time, include_gradients_avg_time,
                       include_path_integrated_gradients,
                       include_diff_sequence_prediction, use_rnn_attention,
                       attention_hidden_layer_dim, volatility_loss_factor):
    """This high-level test ensures there are no errors during training.

    It also checks that the loss is decreasing.

    Args:
      sequence_prediction: Whether to consider the recent predictions in the
        loss or only the last prediction.
      include_gradients: Whether to generate attribution with the
        gradients of the last predictions.
      include_gradients_sum_time: Whether to generate attribution
        with the gradients of the sum of the predictions over time.
      include_gradients_avg_time: Whether to generate attribution
        with the gradients of the average of the predictions over time.
      include_path_integrated_gradients: Whether to generate
        attribution with the integrated gradients of last predictions compared
        to their most recent values before attribution_max_delta_time.
      include_diff_sequence_prediction: Whether to
        generate attribution from the difference of consecutive predictions.
      use_rnn_attention: Whether to use attention for the RNN.
      attention_hidden_layer_dim: If use_rnn_attention is set, the
        dimensionality of the hidden layer (or 0 for none) applied to the last
        output and intermediates before multiplying to obtain a weight.
      volatility_loss_factor: Factor by which the sum of the changes in
        predictions across the sequence is weighted in the loss.
    """
    num_steps = 2
    hparams = contrib_training.HParams(
        batch_size=2,
        learning_rate=0.008,
        sequence_features=[
            'deltaTime', 'Observation.code', 'Observation.valueQuantity.value'
        ],
        categorical_values=['loinc:1', 'loinc:2', 'MISSING'],
        categorical_seq_feature='Observation.code',
        context_features=['sequenceLength'],
        feature_value='Observation.valueQuantity.value',
        label_key='label.in_hospital_death',
        attribution_threshold=-1.0,
        rnn_size=6,
        variational_recurrent_keep_prob=1.1,
        variational_input_keep_prob=1.1,
        variational_output_keep_prob=1.1,
        sequence_prediction=sequence_prediction,
        time_decayed=False,
        normalize=True,
        momentum=0.9,
        min_value=-1000.0,
        max_value=1000.0,
        volatility_loss_factor=volatility_loss_factor,
        attribution_max_delta_time=100000,
        input_keep_prob=1.0,
        include_sequence_prediction=sequence_prediction,
        include_gradients_attribution=include_gradients,
        include_gradients_sum_time_attribution=include_gradients_sum_time,
        include_gradients_avg_time_attribution=include_gradients_avg_time,
        include_path_integrated_gradients_attribution=(
            include_path_integrated_gradients),
        include_diff_sequence_prediction_attribution=(
            include_diff_sequence_prediction),
        use_rnn_attention=use_rnn_attention,
        attention_hidden_layer_dim=attention_hidden_layer_dim,
        path_integrated_gradients_num_steps=10,
    )
    observation_values = tf.SparseTensor(
        indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0],
                 [1, 0, 0], [1, 1, 0], [1, 2, 0]],
        values=[100.0, 2.3, 9999999.0, 0.5, 0.0, 4.0],
        dense_shape=[2, 3, 1])
    model = osm.ObservationSequenceModel()
    model_fn = model.create_model_fn(hparams)
    features = {
        input_fn.CONTEXT_KEY_PREFIX + 'sequenceLength':
            tf.constant([[2], [3]], dtype=tf.int64),
        input_fn.SEQUENCE_KEY_PREFIX + 'Observation.code':
            tf.SparseTensor(
                indices=observation_values.indices,
                values=[
                    'loinc:2', 'loinc:1', 'loinc:2', 'loinc:1', 'MISSING',
                    'loinc:1'
                ],
                dense_shape=observation_values.dense_shape),
        input_fn.SEQUENCE_KEY_PREFIX + 'Observation.valueQuantity.value':
            observation_values,
        input_fn.SEQUENCE_KEY_PREFIX + 'deltaTime':
            tf.constant([[[1], [2], [0]], [[1], [3], [4]]], dtype=tf.int64)
    }
    label_key = 'label.in_hospital_death'
    labels = {label_key: tf.constant([[1.0], [0.0]], dtype=tf.float32)}
    with tf.variable_scope('test'):
      model_fn_ops_train = model_fn(features, labels,
                                    tf_estimator.ModeKeys.TRAIN)
    with tf.variable_scope('test', reuse=True):
      features[input_fn.CONTEXT_KEY_PREFIX + 'label.in_hospital_death'
              ] = tf.SparseTensor(indices=[[0, 0]], values=['expired'],
                                  dense_shape=[2, 1])
      model_fn_ops_eval = model_fn(
          features, labels=None, mode=tf_estimator.ModeKeys.PREDICT)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.tables_initializer())
      # Test train.
      for i in range(num_steps):
        loss, _ = sess.run(
            [model_fn_ops_train.loss, model_fn_ops_train.train_op])
        if i == 0:
          initial_loss = loss
      self.assertLess(loss, initial_loss)
      # Test infer.
      sess.run(model_fn_ops_eval.predictions)


if __name__ == '__main__':
  absltest.main()