# google-research
# 71 lines · 2.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
#!/usr/bin/python
r"""Generates every neutral stoichiometry with a given number of heavy atoms.

One file named `<output_prefix><num_heavy>_<elements>.stoich` is written per
stoichiometry.

Example usage:
  mkdir stoichs
  enumerate_stoichiometries.py --output_prefix=stoichs/ \
    --num_heavy=3 --heavy_elements=C,N,O,S
"""
24
25from absl import app26from absl import flags27
28from graph_sampler import stoichiometry29
FLAGS = flags.FLAGS

# Required at runtime: the __main__ guard marks it as required, so the None
# default is never used when run as a script.
flags.DEFINE_integer('num_heavy', None, 'Number of non-hydrogen atoms.')
flags.DEFINE_list('heavy_elements', ['C', 'N', 'N+', 'O', 'O-', 'F'],
                  'Which heavy elements to use.')
flags.DEFINE_string('output_prefix', '', 'Prefix for output files.')
# The two list flags below take NAME=VALUE items (see the examples in their
# help text); main() converts them with stoichiometry.parse_dict_flag.
flags.DEFINE_list(
    'valences', [],
    'Valences of atom types (only required for atom types whose valence cannot '
    'be inferred by rdkit, e.g. "X=7,R=3" if you\'re using "synthetic atoms" '
    'with valences 7 and 3).')
flags.DEFINE_list(
    'charges', [],
    'Charges of atom types (only required for atom types whose charge cannot '
    'be inferred by rdkit, e.g. "X=0,R=-1" if you\'re using "synthetic atoms" '
    'with charges 0 and -1).')
47
def main(argv):
  """Enumerates stoichiometries and writes each one to its own file.

  For every stoichiometry yielded by `stoichiometry.enumerate_stoichiometries`,
  prints the concatenated element string to stdout and writes the
  stoichiometry to `<output_prefix><num_heavy>_<elements>.stoich`. A summary
  line with the number of files written is printed at the end.

  Side effect: FLAGS.valences and FLAGS.charges are overwritten with the
  return value of `stoichiometry.parse_dict_flag` (presumably NAME -> value
  mappings; confirm against graph_sampler.stoichiometry).

  Args:
    argv: Arguments left over after absl flag parsing; only the program name
      (argv[0]) is expected.

  Raises:
    RuntimeError: If any positional arguments beyond argv[0] are supplied.
  """
  extra_args = argv[1:]
  if extra_args:
    raise RuntimeError(f'Unexpected arguments: {extra_args}')

  # The parsed values are stored back on FLAGS before enumeration starts.
  FLAGS.valences = stoichiometry.parse_dict_flag(FLAGS.valences)
  FLAGS.charges = stoichiometry.parse_dict_flag(FLAGS.charges)

  all_stoichs = stoichiometry.enumerate_stoichiometries(
      FLAGS.num_heavy, FLAGS.heavy_elements, FLAGS.valences, FLAGS.charges)

  # Stays 0 when the enumeration yields nothing; otherwise enumerate() leaves
  # it equal to the number of files written.
  num_written = 0
  for num_written, stoich in enumerate(all_stoichs, start=1):
    elements = ''.join(stoich.to_element_list())
    out_path = f'{FLAGS.output_prefix}{FLAGS.num_heavy:d}_{elements}.stoich'
    print(elements)
    with open(out_path, 'w') as out_file:
      stoich.write(out_file)

  print(f'{num_written} files written!')
68
if __name__ == '__main__':
  # num_heavy has no usable default (None), so absl must reject invocations
  # that omit it before main() runs.
  flags.mark_flags_as_required(['num_heavy'])
  app.run(main)