google-research

Форк
0
138 строк · 4.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utility for plotting character cooccurrence matrices.
17

18
List of available seaborn palettes:
19
-----------------------------------
20
Also see: https://seaborn.pydata.org/tutorial/color_palettes.html
21
"""
22

23
import logging
24

25
from typing import Sequence
26

27
from absl import app
28
from absl import flags
29

30
import matplotlib.pyplot as plt
31
import numpy as np
32
import pandas as pd
33
import seaborn as sns
34

35
flags.DEFINE_string(
36
    "input_file", None,
37
    ("Path to the text file containing the correlations. The file is produced "
38
     "by the `penn_choma` tool."))
39

40
flags.DEFINE_string(
41
    "output_file", None,
42
    ("Output file for the resulting figure (based on the extension: PDF, PNG, "
43
     "etc.)."))
44

45
flags.DEFINE_string(
46
    "color_map", "vlag",
47
    ("Color map to use. Some examples: \"PuBu\", \"vlag\", "
48
     "\"viridis\", \"mako\"."))
49

50
flags.DEFINE_integer(
51
    "dpi", 700, "Dots per pixel (DPI) resolution for the final images.")
52

53
FLAGS = flags.FLAGS
54

55
_PAIR_PREFIX = "1:"
56

57

58
def _plot(df):
59
  """Plots the heatmap defined in the dataframe."""
60
  plt.rcParams["font.family"] = "sans-serif"
61
  plt.rcParams["font.sans-serif"] = [
62
      "Noto Sans CJK SC", "Noto Sans CJK KR", "sans-serif"]
63
  logging.info("Saving PDF to %s ...", FLAGS.output_file)
64
  fig = plt.figure()
65
  sns.set(font_scale=0.7)
66
  ax = sns.heatmap(df, fmt=".2g", cmap=FLAGS.color_map,
67
                   cbar_kws=dict(use_gridspec=False,
68
                                 orientation="horizontal",
69
                                 pad=0.02,
70
                                 shrink=0.60),
71
                   square=True, vmin=-0.3, vmax=1.0, center=0.4,
72
                   yticklabels=False, xticklabels=False)
73
  ax.set_aspect("equal")
74
  fig.savefig(FLAGS.output_file, bbox_inches="tight",
75
              pad_inches=0, dpi=FLAGS.dpi)
76

77

78
def _process_input_file():
79
  """Processes input file and plots the results."""
80
  logging.info("Reading %s ...", FLAGS.input_file)
81

82
  # Parse the file into intermediate data structures.
83
  corr_pairs = []
84
  char_to_id = {}
85
  num_chars = 0
86
  with open(FLAGS.input_file, mode="r", encoding="utf8") as f:
87
    lines = [line.rstrip() for line in f.readlines()]
88
    for line in lines:
89
      if not line.startswith(_PAIR_PREFIX):
90
        continue
91
      toks = line.replace(_PAIR_PREFIX, "").split("\t")
92
      if len(toks) != 2:
93
        raise ValueError("Expected two tokens in {}".format(toks))
94
      char_toks = toks[0].split(",")
95
      if len(char_toks) != 2:
96
        raise ValueError("Expected two char tokens in {}".format(char_toks))
97

98
      if char_toks[0] not in char_to_id:
99
        from_char_id = num_chars
100
        char_to_id[char_toks[0]] = from_char_id
101
        num_chars += 1
102
      else:
103
        from_char_id = char_to_id[char_toks[0]]
104
      if char_toks[1] not in char_to_id:
105
        to_char_id = num_chars
106
        char_to_id[char_toks[1]] = to_char_id
107
        num_chars += 1
108
      else:
109
        to_char_id = char_to_id[char_toks[1]]
110
      corr_pairs.append(((from_char_id, to_char_id), float(toks[1])))
111
  sqrt_num_chars = int(np.sqrt(len(corr_pairs)))
112
  assert num_chars == sqrt_num_chars
113
  logging.info("Read %d correlation pairs, %d characters.", len(corr_pairs),
114
               num_chars)
115

116
  # Convert raw containers to Pandas data frame and plot.
117
  sorted_by_id = [char for char, _ in sorted(
118
      char_to_id.items(), key=lambda item: item[1])]
119
  corr = np.zeros((num_chars, num_chars))
120
  for chars, r in corr_pairs:
121
    corr[chars[0], chars[1]] = r
122
  corr_df = pd.DataFrame(corr, columns=sorted_by_id, index=sorted_by_id)
123
  corr_df.info()
124
  _plot(corr_df)
125

126

127
def main(argv):
128
  if len(argv) > 1:
129
    raise app.UsageError("Too many command-line arguments.")
130
  if not FLAGS.input_file:
131
    raise app.UsageError("Specify --input_file")
132
  if not FLAGS.output_file:
133
    raise app.UsageError("Specify --output_file")
134
  _process_input_file()
135

136

137
if __name__ == "__main__":
138
  app.run(main)
139

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.