google-research

Форк
0
/
reject_to_uniform.py 
72 строки · 2.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
#!/usr/bin/python
17
r"""Reject graphs based on importance to produce a uniform sample set.
18

19
Usage:
20
prefix=3_COFH
21
./reject_to_uniform.py \
22
    --in_file=weighted/${prefix}.graphml \
23
    --out_file=uniform/${prefix}.graphml
24
"""
25

26
from absl import app
27
from absl import flags
28

29
from graph_sampler import graph_io
30
from graph_sampler import molecule_sampler
31

32
# Handle to absl's global flag registry; main() reads the parsed values from it.
FLAGS = flags.FLAGS

# Command-line flags of this script. Each one is a string defaulting to None.
flags.DEFINE_string(name='in_file', default=None, help='Input file path.')
flags.DEFINE_string(name='out_file', default=None, help='Output file path.')
# The seed value is handed over unchanged as `rng_seed` to the rejector.
flags.DEFINE_string(
    name='seed', default=None, help='Seed used for random number generation.')
37

38

39
def _acceptance_percent(num_accepted, num_processed):
  """Returns the accepted share as a percentage, or 0.0 if nothing was processed."""
  # An input without graphs leaves the processed counter at zero; guard the
  # division so the final summary does not die with ZeroDivisionError.
  if not num_processed:
    return 0.0
  return num_accepted / num_processed * 100


def main(argv):
  """Filters the graphs of --in_file through RejectToUniform into --out_file.

  Reads the input stats, streams the graphs of the input file through
  `molecule_sampler.RejectToUniform`, writes every graph the rejector yields
  to the output file, appends a stats record (`num_samples`,
  `estimated_num_graphs`, `rng_seed`) and prints acceptance counts to stdout.

  Args:
    argv: Arguments left over after flag parsing. Only the program name is
      expected.

  Raises:
    RuntimeError: If any extra positional arguments were supplied.
  """
  if len(argv) > 1:
    raise RuntimeError(f'Unexpected arguments: {argv[1:]}')

  input_stats = graph_io.get_stats(FLAGS.in_file)
  max_importance = input_stats['max_final_importance']
  with open(FLAGS.in_file) as input_file:
    rejector = molecule_sampler.RejectToUniform(
        base_iter=graph_io.graph_reader(input_file),
        max_importance=max_importance,
        rng_seed=FLAGS.seed)
    with open(FLAGS.out_file, 'w') as output_file:
      for graph in rejector:
        graph_io.write_graph(graph, output_file)
        # Progress report every 10000 accepted graphs.
        if rejector.num_accepted % 10000 == 0:
          acc = rejector.num_accepted
          proc = rejector.num_processed
          print(f'Accepted {acc}/{proc}: '
                f'{_acceptance_percent(acc, proc):.2f}%')

      # The stats record is written after all accepted graphs.
      output_stats = dict(
          num_samples=rejector.num_accepted,
          estimated_num_graphs=input_stats['estimated_num_graphs'],
          rng_seed=rejector.rng_seed)
      graph_io.write_stats(output_stats, output_file)

  acc = rejector.num_accepted
  proc = rejector.num_processed
  print(f'Done rejecting to uniform! Accepted {acc}/{proc}: '
        f'{_acceptance_percent(acc, proc):.2f}%')
68

69

70
if __name__ == '__main__':
  # Both file paths are marked as required flags before absl runs main().
  flags.mark_flag_as_required('in_file')
  flags.mark_flag_as_required('out_file')
  app.run(main)
73

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.