google-research
99 строк · 3.2 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16#!/usr/bin/python
17r"""Program to sample molecules from a given stoichiometry.
18
19Example usage:
20prefix=3_COFH
21./sample_molecules.py --min_samples=3000 \
22--stoich_file=stoichs/${prefix}.stoich \
23--out_file=weighted/${prefix}.graphml
24"""
25
26import sys
27import timeit
28
29from absl import app
30from absl import flags
31
32from graph_sampler import graph_io
33from graph_sampler import molecule_sampler
34from graph_sampler import stoichiometry
35
FLAGS = flags.FLAGS

# Path of the stoichiometry file handed to stoichiometry.read(); marked required
# in the __main__ guard.
flags.DEFINE_string('stoich_file', None, 'Csv file with desired stoichiometry.')
# Forwarded to MoleculeSampler as min_samples.
flags.DEFINE_integer('min_samples', 10000, 'Minimum number of samples.')
# Forwarded to MoleculeSampler as relative_precision.
flags.DEFINE_float(
    'relative_precision', 0.01,
    'Keep sampling until (std_err / estimate) is less than this number.')
# Forwarded to MoleculeSampler as min_uniform_proportion (None by default).
flags.DEFINE_float(
    'min_uniform_proportion', None,
    'Keep sampling until this set of samples can be rejected down to a '
    'uniform sample containing at least this proportion of the estimated '
    'number of graphs.')
# Destination for the sampled graphs and final stats; marked required in the
# __main__ guard.
flags.DEFINE_string('out_file', None, 'Output file path.')
# NOTE(review): defined as a *string* flag but forwarded as
# MoleculeSampler(rng_seed=...); confirm the sampler accepts a string seed
# (DEFINE_integer may have been intended).
flags.DEFINE_string('seed', None, 'Seed used for random number generation.')
50
51
def main(argv):
  """Samples molecules for --stoich_file and writes them to --out_file.

  Reads the stoichiometry file, builds a MoleculeSampler from the command-line
  flags, and writes every graph the sampler yields to the output file. After
  the last graph, it writes the sampler's stats, augmented with an
  'elapsed time' entry, to the same file. A progress line goes to stdout after
  every 10000 valid graphs. A final progress line is printed after
  "Done generating molecules!" only when the valid count is not a multiple of
  10000. In particular, nothing is reported when no graphs were produced.

  Args:
    argv: Positional arguments left after flag parsing; only argv[0] (the
      program name) is accepted.

  Raises:
    RuntimeError: If extra positional arguments were supplied.
  """
  unexpected = argv[1:]
  if unexpected:
    raise RuntimeError(f'Unexpected arguments: {unexpected}')

  stoich_path = FLAGS.stoich_file
  print(f'Reading stoich from: {stoich_path}')
  with open(stoich_path) as stoich_f:
    stoich = stoichiometry.read(stoich_f)

  sampler = molecule_sampler.MoleculeSampler(
      stoich,
      min_samples=FLAGS.min_samples,
      min_uniform_proportion=FLAGS.min_uniform_proportion,
      relative_precision=FLAGS.relative_precision,
      rng_seed=FLAGS.seed)
  # Timing starts after sampler construction, so it measures sampling + I/O.
  t0 = timeit.default_timer()
  report_every = 10000  # Valid graphs between progress lines.
  n_valid = 0

  def report():
    """Prints one progress line built from the sampler's current stats."""
    s = sampler.stats()
    estimate = s['estimated_num_graphs']
    rel_err = s['num_graphs_std_err'] / estimate
    kept_frac = s['num_after_rejection'] / estimate
    elapsed = timeit.default_timer() - t0
    print(f'Sampled {s["num_samples"]} ({n_valid} valid), '
          f'{elapsed:.03f} sec, '
          f'{estimate:.3E} graphs '
          f'(std err={100 * rel_err:.3f}%), '
          f'proportion after rejection={kept_frac:.3E}')
    # Push the line out immediately so progress shows up when stdout is piped.
    sys.stdout.flush()

  with open(FLAGS.out_file, 'w') as out:
    for n_valid, graph in enumerate(sampler, start=1):
      graph_io.write_graph(graph, out)
      if not n_valid % report_every:
        report()

    # Stats are written into the same (still open) output file.
    final_stats = sampler.stats()
    final_stats['elapsed time'] = timeit.default_timer() - t0
    graph_io.write_stats(final_stats, out)

  print('Done generating molecules!')
  # Skip the final report if the loop's last iteration already printed one.
  if n_valid % report_every:
    report()
95
96
if __name__ == '__main__':
  # Both paths have no usable default (None), so absl must reject a run
  # that omits either of them.
  required_flags = ['stoich_file', 'out_file']
  flags.mark_flags_as_required(required_flags)
  app.run(main)
100