google-research
72 строки · 2.3 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/python
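r"""Reject graphs based on importance to produce a uniform sample set.

The accepted graphs are written to --out_file, followed by summary stats
(number of samples, estimated number of graphs, and the RNG seed).

Usage:
  prefix=3_COFH
  ./reject_to_uniform.py \
    --in_file=weighted/${prefix}.graphml \
    --out_file=uniform/${prefix}.graphml

Optionally pass --seed to set the seed used for random number generation.
"""
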
from absl import app
from absl import flags

from graph_sampler import graph_io
from graph_sampler import molecule_sampler

FLAGS = flags.FLAGS

# Input (importance-weighted) and output (uniform) graph files; see the module
# usage example above.
flags.DEFINE_string(name='in_file', default=None, help='Input file path.')
flags.DEFINE_string(name='out_file', default=None, help='Output file path.')
# NOTE(review): the seed is parsed as a *string* and forwarded unchanged to
# molecule_sampler.RejectToUniform(rng_seed=...). Confirm the sampler accepts a
# string seed; an integer flag may have been intended.
flags.DEFINE_string(
    name='seed', default=None, help='Seed used for random number generation.')


def _acceptance_summary(num_accepted, num_processed):
  """Returns an 'Accepted <a>/<p>: <rate>%' progress string.

  Args:
    num_accepted: Number of graphs accepted so far.
    num_processed: Number of graphs examined so far.

  Returns:
    The formatted summary. When nothing has been processed (e.g. an empty
    input file), the rate is reported as 0.00% instead of raising
    ZeroDivisionError.
  """
  rate = num_accepted / num_processed * 100 if num_processed else 0.0
  return f'Accepted {num_accepted}/{num_processed}: {rate:.2f}%'


def main(argv):
  """Rejection-samples FLAGS.in_file into a uniform sample in FLAGS.out_file.

  Args:
    argv: Positional arguments left over after absl flag parsing; only the
      program name is expected.

  Raises:
    app.UsageError: If unexpected positional arguments are given.
  """
  if len(argv) > 1:
    # UsageError makes absl print the usage text and exit non-zero instead of
    # dumping a traceback.
    raise app.UsageError(f'Unexpected arguments: {argv[1:]}')

  # max_final_importance from the input's stats is handed to the rejector
  # (presumably as the normalizer for acceptance probabilities).
  input_stats = graph_io.get_stats(FLAGS.in_file)
  max_importance = input_stats['max_final_importance']
  with open(FLAGS.in_file) as input_file:
    rejector = molecule_sampler.RejectToUniform(
        base_iter=graph_io.graph_reader(input_file),
        max_importance=max_importance,
        rng_seed=FLAGS.seed)
    # The output file is only opened once the rejector was built successfully.
    with open(FLAGS.out_file, 'w') as output_file:
      for graph in rejector:
        graph_io.write_graph(graph, output_file)
        # Progress report every 10000 accepted graphs.
        if rejector.num_accepted % 10000 == 0:
          print(_acceptance_summary(rejector.num_accepted,
                                    rejector.num_processed))

      # Stats go into the same output file, after all accepted graphs.
      output_stats = dict(
          num_samples=rejector.num_accepted,
          estimated_num_graphs=input_stats['estimated_num_graphs'],
          rng_seed=rejector.rng_seed)
      graph_io.write_stats(output_stats, output_file)

  print('Done rejecting to uniform! ' +
        _acceptance_summary(rejector.num_accepted, rejector.num_processed))


if __name__ == '__main__':
  # absl rejects the command line during flag parsing if either path is unset.
  flags.mark_flag_as_required('in_file')
  flags.mark_flag_as_required('out_file')
  app.run(main)