# google-research (source listing: 192 lines, 6.8 KB)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities for calculating dataset statistics."""
17
import copy
import logging
import os
import tempfile

import contextlib2
import tensorflow.compat.v1 as tf
from xxx import metrics as contrib_metrics
from tensorflow.contrib import labeled_tensor as lt

# Google Internal
import text_format
import gfile

from ..learning import data
32logger = logging.getLogger(__name__)33
34
def experiment_has_statistics(experiment_proto):
  """Returns True if every named output in the experiment has statistics.

  Args:
    experiment_proto: selection_pb2.Experiment to inspect. Reads and
      additional outputs with an empty name are ignored.

  Returns:
    True if every named positive/negative read and every named additional
    output has its 'statistics' field set (vacuously True when there are no
    named outputs), False otherwise.
  """
  for round_proto in experiment_proto.rounds.values():
    for reads in [round_proto.positive_reads, round_proto.negative_reads]:
      # Short-circuit: a single missing statistics field settles the answer.
      if reads.name and not reads.HasField('statistics'):
        return False
  return all(
      ao_proto.HasField('statistics')
      for ao_proto in experiment_proto.additional_output
      if ao_proto.name)
50
def compute_experiment_statistics(
    experiment_proto,
    input_paths,
    proto_w_stats_path,
    preprocess_mode=data.PREPROCESS_SKIP_ALL_ZERO_COUNTS,
    max_size=None,
    logdir=None,
    save_stats=False):
  """Calculate the mean and standard deviation of counts from input files.

  These statistics are used for normalization. If any statistic is missing or
  save_stats=True, compute the statistics. Save the statistics to
  proto_w_stats_path if save_stats=True.

  Args:
    experiment_proto: selection_pb2.Experiment describing the experiment.
    input_paths: list of strings giving paths to sstables of input examples.
    proto_w_stats_path: string path to the validation proto file with stats.
    preprocess_mode: optional preprocess mode defined in the `data` module.
    max_size: optional number of examples to examine to compute statistics. By
      default, examines the entire dataset.
    logdir: optional path to a directory in which to log events.
    save_stats: optional boolean indicating whether to update all the
      statistics and save to proto_w_stats_path.

  Returns:
    selection_pb2.Experiment with computed statistics.
  """
  # Work on a deep copy so the caller's proto is never mutated.
  experiment_proto = copy.deepcopy(experiment_proto)

  has_all_statistics = True

  # Collect every named read output, keyed by name, and note whether any of
  # them is missing its statistics field.
  all_reads = {}
  for round_proto in experiment_proto.rounds.values():
    for reads in [round_proto.positive_reads, round_proto.negative_reads]:
      if reads.name:
        all_reads[reads.name] = reads
        if not reads.HasField('statistics'):
          has_all_statistics = False

  # Same for the additional outputs.
  all_ao = {}
  for ao_proto in experiment_proto.additional_output:
    if ao_proto.name:
      all_ao[ao_proto.name] = ao_proto
      if not ao_proto.HasField('statistics'):
        has_all_statistics = False

  if not has_all_statistics or save_stats:
    with tf.Graph().as_default():
      logger.info('Setting up graph for statistics')
      # We only care about outputs, which don't rely on training
      # hyperparameters.
      hps = tf.HParams(
          preprocess_mode=preprocess_mode,
          kmer_k_max=0,
          ratio_random_dna=0.0,
          total_reads_defining_positive=0,
          additional_output=','.join(
              [x.name for x in experiment_proto.additional_output]))
      _, outputs = data.input_pipeline(
          input_paths,
          experiment_proto,
          final_mbsz=100000,
          hps=hps,
          num_epochs=1,
          num_threads=1)
      size_op = tf.shape(outputs)[list(outputs.axes.keys()).index('batch')]

      def _statistics_ops(tensor):
        # Streaming mean/std of both the raw values and log(x + 1), as one
        # contrib_metrics metric map (shared by reads and additional outputs).
        log_tensor = lt.log(tensor + 1.0)
        ops = {
            'mean': contrib_metrics.streaming_mean(tensor),
            'std_dev': streaming_std(tensor),
            'mean_log_plus_one': contrib_metrics.streaming_mean(log_tensor),
            'std_dev_log_plus_one': streaming_std(log_tensor),
        }
        return contrib_metrics.aggregate_metric_map(ops)

      all_update_ops = []
      all_value_ops = {}
      # Reads first, then additional outputs, matching the collection order.
      for name in list(all_reads) + list(all_ao):
        value_ops, update_ops = _statistics_ops(
            lt.select(outputs, {'output': name}))
        all_update_ops.extend(list(update_ops.values()))
        all_value_ops[name] = value_ops

      logger.info('Running statistics ops')
      sv = tf.train.Supervisor(logdir=logdir)
      with sv.managed_session() as sess:
        total = 0
        for results in run_until_exhausted(sv, sess,
                                           [size_op] + all_update_ops):
          total += results[0]
          # Stop early once we have seen at least max_size examples.
          if max_size is not None and total >= max_size:
            break
        all_statistics = {k: sess.run(v) for k, v in all_value_ops.items()}

    # Copy the computed scalars back into the protos' statistics fields.
    for reads_name, reads in all_reads.items():
      for name, value in all_statistics[reads_name].items():
        setattr(reads.statistics, name, value.item())

    for ao_name, ao in all_ao.items():
      for name, value in all_statistics[ao_name].items():
        setattr(ao.statistics, name, value.item())

    logger.info('Computed statistics: %r', all_statistics)

    if save_stats:
      logger.info('Save the proto with statistics to %s', proto_w_stats_path)
      # Write to a unique temporary file rather than a fixed /tmp path, which
      # races (and cross-clobbers) between concurrent runs, then copy it to
      # the (possibly non-local) destination filesystem via gfile.
      with tempfile.NamedTemporaryFile(
          mode='w', suffix='.pbtxt', delete=False) as f:
        f.write(text_format.MessageToString(experiment_proto))
        tmp_path = f.name
      try:
        gfile.Copy(tmp_path, proto_w_stats_path, overwrite=True)
      finally:
        os.remove(tmp_path)
  else:
    logger.info('All the statistics exist. Nothing to compute')
  return experiment_proto
177
def streaming_std(tensor):
  """Streaming standard deviation metric for `tensor`.

  Computed as sqrt(E[x^2] - E[x]^2) from two streaming-mean metrics, so the
  estimate updates incrementally as batches are consumed.

  Args:
    tensor: tensor whose standard deviation to accumulate across batches.

  Returns:
    A (value_op, update_op) pair following the contrib_metrics convention:
    value_op evaluates the current standard deviation estimate; update_op
    folds in another batch.
  """
  first_moment, first_update = contrib_metrics.streaming_mean(tensor)
  second_moment, second_update = contrib_metrics.streaming_mean(
      tf.square(tensor))
  variance = second_moment - tf.square(first_moment)
  return tf.sqrt(variance), tf.group(first_update, second_update)
187
def run_until_exhausted(supervisor, session, fetches):
  """Repeatedly runs `fetches`, yielding results until input is exhausted.

  Iteration ends cleanly when the input pipeline raises OutOfRangeError or
  when the supervisor requests a stop.

  Args:
    supervisor: tf.train.Supervisor coordinating the session.
    session: session in which to evaluate `fetches`.
    fetches: ops/tensors handed to `session.run` on every iteration.

  Yields:
    The result of each `session.run(fetches)` call.
  """
  try:
    while not supervisor.should_stop():
      yield session.run(fetches)
  except tf.errors.OutOfRangeError:
    # End of the input pipeline; terminate the generator silently.
    pass