# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for calculating dataset statistics."""

import copy
import logging

import contextlib2
import tensorflow.compat.v1 as tf
from tensorflow.contrib import metrics as contrib_metrics
from tensorflow.contrib import labeled_tensor as lt
from tensorflow.contrib import training as contrib_training

from google.protobuf import text_format
from tensorflow.compat.v1 import gfile

from ..learning import data

logger = logging.getLogger(__name__)


def experiment_has_statistics(experiment_proto):
  """Returns True if all reads and additional outputs have statistics."""
  has_all_statistics = True
  for round_proto in experiment_proto.rounds.values():
    for reads in [round_proto.positive_reads, round_proto.negative_reads]:
      if reads.name:
        if not reads.HasField('statistics'):
          has_all_statistics = False

  for ao_proto in experiment_proto.additional_output:
    if ao_proto.name:
      if not ao_proto.HasField('statistics'):
        has_all_statistics = False
  return has_all_statistics
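
# Typical guard (a sketch; `input_paths` and `proto_w_stats_path` are
# hypothetical call-site variables):
#
#   if not experiment_has_statistics(experiment_proto):
#     experiment_proto = compute_experiment_statistics(
#         experiment_proto, input_paths, proto_w_stats_path)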


def compute_experiment_statistics(
    experiment_proto,
    input_paths,
    proto_w_stats_path,
    preprocess_mode=data.PREPROCESS_SKIP_ALL_ZERO_COUNTS,
    max_size=None,
    logdir=None,
    save_stats=False):
  """Calculate the mean and standard deviation of counts from input files.

  These statistics are used for normalization. If any statistic is missing or
  save_stats=True, compute the statistics. Save the statistics to
  proto_w_stats_path if save_stats=True.

  Args:
    experiment_proto: selection_pb2.Experiment describing the experiment.
    input_paths: list of strings giving paths to sstables of input examples.
    proto_w_stats_path: string path to the validation proto file with stats.
    preprocess_mode: optional preprocess mode defined in the `data` module.
    max_size: optional number of examples to examine to compute statistics. By
      default, examines the entire dataset.
    logdir: optional path to a directory in which to log events.
    save_stats: optional boolean indicating whether to update all the
      statistics and save to proto_w_stats_path.

  Returns:
    selection_pb2.Experiment with computed statistics.
  """
  experiment_proto = copy.deepcopy(experiment_proto)

  has_all_statistics = True

  all_reads = {}
  for round_proto in experiment_proto.rounds.values():
    for reads in [round_proto.positive_reads, round_proto.negative_reads]:
      if reads.name:
        all_reads[reads.name] = reads
        if not reads.HasField('statistics'):
          has_all_statistics = False

  all_ao = {}
  for ao_proto in experiment_proto.additional_output:
    if ao_proto.name:
      all_ao[ao_proto.name] = ao_proto
      if not ao_proto.HasField('statistics'):
        has_all_statistics = False

  if not has_all_statistics or save_stats:
    with tf.Graph().as_default():
      logger.info('Setting up graph for statistics')
      # We only care about outputs, which don't rely on training
      # hyperparameters.
      hps = contrib_training.HParams(
          preprocess_mode=preprocess_mode,
          kmer_k_max=0,
          ratio_random_dna=0.0,
          total_reads_defining_positive=0,
          additional_output=','.join([
              x.name for x in experiment_proto.additional_output]))
      _, outputs = data.input_pipeline(
          input_paths,
          experiment_proto,
          final_mbsz=100000,
          hps=hps,
          num_epochs=1,
          num_threads=1)
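      # The minibatch size: length of the 'batch' axis of the outputs tensor.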
      size_op = tf.shape(outputs)[list(outputs.axes.keys()).index('batch')]

      all_update_ops = []
      all_value_ops = {}
      for name in all_reads:
        counts = lt.select(outputs, {'output': name})
        log_counts = lt.log(counts + 1.0)
        ops = {
            'mean': contrib_metrics.streaming_mean(counts),
            'std_dev': streaming_std(counts),
            'mean_log_plus_one': contrib_metrics.streaming_mean(log_counts),
            'std_dev_log_plus_one': streaming_std(log_counts),
        }
        value_ops, update_ops = contrib_metrics.aggregate_metric_map(ops)
        all_update_ops.extend(list(update_ops.values()))
        all_value_ops[name] = value_ops

      for name in all_ao:
        ao = lt.select(outputs, {'output': name})
        log_ao = lt.log(ao + 1.0)
        ops = {
            'mean': contrib_metrics.streaming_mean(ao),
            'std_dev': streaming_std(ao),
            'mean_log_plus_one': contrib_metrics.streaming_mean(log_ao),
            'std_dev_log_plus_one': streaming_std(log_ao),
        }
        value_ops, update_ops = contrib_metrics.aggregate_metric_map(ops)
        all_update_ops.extend(list(update_ops.values()))
        all_value_ops[name] = value_ops

      logger.info('Running statistics ops')
      sv = tf.train.Supervisor(logdir=logdir)
      with sv.managed_session() as sess:
        total = 0
        for results in run_until_exhausted(sv, sess,
                                           [size_op] + all_update_ops):
          total += results[0]
          if max_size is not None and total >= max_size:
            break
        all_statistics = {k: sess.run(v) for k, v in all_value_ops.items()}

      for reads_name, reads in all_reads.items():
        for name, value in all_statistics[reads_name].items():
          setattr(reads.statistics, name, value.item())

      for ao_name, ao in all_ao.items():
        for name, value in all_statistics[ao_name].items():
          setattr(ao.statistics, name, value.item())

      logger.info('Computed statistics: %r', all_statistics)

      if save_stats:
        logger.info('Save the proto with statistics to %s', proto_w_stats_path)
        with open('/tmp/tmp.pbtxt', 'w') as f:
          f.write(text_format.MessageToString(experiment_proto))
        gfile.Copy('/tmp/tmp.pbtxt', proto_w_stats_path, overwrite=True)
  else:
    logger.info('All the statistics exist. Nothing to compute')
  return experiment_proto
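
# A minimal end-to-end sketch (hypothetical paths; assumes `selection_pb2`
# is importable from the surrounding package):
#
#   experiment_proto = selection_pb2.Experiment()
#   with gfile.GFile('/path/to/experiment.pbtxt') as f:
#     text_format.Parse(f.read(), experiment_proto)
#   experiment_proto = compute_experiment_statistics(
#       experiment_proto,
#       input_paths=['/path/to/examples.sstable'],
#       proto_w_stats_path='/path/to/experiment_with_stats.pbtxt',
#       save_stats=True)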


def streaming_std(tensor):
  """Builds streaming ops for the standard deviation of `tensor`.

  Uses the identity Var[X] = E[X^2] - (E[X])^2, accumulating both
  expectations with streaming means.

  Args:
    tensor: tensor whose standard deviation should be computed.

  Returns:
    Tuple (value_op, update_op), following the contrib.metrics convention.
  """
  mean_value, mean_update = contrib_metrics.streaming_mean(tensor)
  mean_squared_value, mean_squared_update = contrib_metrics.streaming_mean(
      tf.square(tensor))
  value_op = tf.sqrt(mean_squared_value - tf.square(mean_value))
  update_op = tf.group(mean_update, mean_squared_update)
  return value_op, update_op
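
# A sanity-check sketch for streaming_std (hypothetical, not part of the
# library; assumes a TF1 graph-mode session):
#
#   x = tf.constant([1.0, 2.0, 3.0, 4.0])
#   std_value, std_update = streaming_std(x)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(std_update)
#     print(sess.run(std_value))  # ~1.118, the population std of [1, 2, 3, 4]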


def run_until_exhausted(supervisor, session, fetches):
  """Run the given fetches until OutOfRangeError is triggered."""
  with contextlib2.suppress(tf.errors.OutOfRangeError):
    while not supervisor.should_stop():
      yield session.run(fetches)