google-research
138 строк · 4.4 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Utility for plotting character cooccurrence matrices.

The color map of the heatmap is chosen with the --color_map flag (for example
"vlag", "mako", "viridis" or "PuBu"). For the available seaborn palettes, see:
https://seaborn.pydata.org/tutorial/color_palettes.html
"""
22
23import logging
24
25from typing import Sequence
26
27from absl import app
28from absl import flags
29
30import matplotlib.pyplot as plt
31import numpy as np
32import pandas as pd
33import seaborn as sns
34
flags.DEFINE_string(
    "input_file", None,
    ("Path to the text file containing the correlations. The file is produced "
     "by the `penn_choma` tool."))

flags.DEFINE_string(
    "output_file", None,
    ("Output file for the resulting figure (based on the extension: PDF, PNG, "
     "etc.)."))

flags.DEFINE_string(
    "color_map", "vlag",
    ("Color map to use. Some examples: \"PuBu\", \"vlag\", "
     "\"viridis\", \"mako\"."))

# Passed as `dpi` to `Figure.savefig`, where it is measured in dots per inch.
flags.DEFINE_integer(
    "dpi", 700, "Dots per inch (DPI) resolution for the final images.")

FLAGS = flags.FLAGS

# Prefix identifying the input-file lines that hold a character-pair
# correlation ("<prefix><char1>,<char2>\t<value>"); other lines are skipped.
_PAIR_PREFIX = "1:"
56
57
def _plot(df):
    """Plots `df` as a heatmap and saves the figure to `--output_file`.

    The output format is chosen by matplotlib from the file extension, the
    resolution comes from `--dpi` and the color map from `--color_map`. Axis
    tick labels are not drawn.

    Args:
      df: `pandas.DataFrame` holding the matrix to plot. In this script it is
        the square character-by-character correlation matrix.
    """
    # Apply the seaborn theme FIRST: `sns.set()` (via `set_style`) resets
    # matplotlib's "font.sans-serif" list to seaborn's defaults, which would
    # silently discard the CJK font preferences configured right after it.
    sns.set(font_scale=0.7)
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.sans-serif"] = [
        "Noto Sans CJK SC", "Noto Sans CJK KR", "sans-serif"]

    logging.info("Saving figure to %s ...", FLAGS.output_file)
    fig = plt.figure()
    ax = sns.heatmap(df, fmt=".2g", cmap=FLAGS.color_map,
                     cbar_kws=dict(use_gridspec=False,
                                   orientation="horizontal",
                                   pad=0.02,
                                   shrink=0.60),
                     square=True, vmin=-0.3, vmax=1.0, center=0.4,
                     yticklabels=False, xticklabels=False)
    ax.set_aspect("equal")
    fig.savefig(FLAGS.output_file, bbox_inches="tight",
                pad_inches=0, dpi=FLAGS.dpi)
    # Release the figure; pyplot otherwise keeps a reference to it.
    plt.close(fig)
76
77
def _read_correlations(path):
    """Parses the pairwise character correlations in `path` into a data frame.

    Only lines starting with `_PAIR_PREFIX` are used. Each such line must look
    like `<prefix><char1>,<char2>\t<value>`. Characters get consecutive integer
    ids in order of first appearance.

    Args:
      path: Path to the UTF-8 encoded correlations file.

    Returns:
      A square `pandas.DataFrame` whose row and column labels are the
      characters (in first-appearance order). Cell [a, b] holds the value read
      for the pair "a,b".

    Raises:
      ValueError: If a pair line is malformed, if the file contains no pair
        lines, or if the number of pairs is not the square of the number of
        distinct characters.
    """
    corr_pairs = []
    char_to_id = {}
    with open(path, mode="r", encoding="utf8") as f:
        for raw_line in f:
            line = raw_line.rstrip()
            if not line.startswith(_PAIR_PREFIX):
                continue
            # Strip only the leading prefix; `str.replace` would also delete any
            # later occurrence of the prefix text inside the line.
            toks = line[len(_PAIR_PREFIX):].split("\t")
            if len(toks) != 2:
                raise ValueError("Expected two tokens in {}".format(toks))
            char_toks = toks[0].split(",")
            if len(char_toks) != 2:
                raise ValueError(
                    "Expected two char tokens in {}".format(char_toks))
            # `len(char_to_id)` is evaluated before insertion, so an unseen
            # character receives the next free id; a seen one keeps its id.
            from_char_id = char_to_id.setdefault(char_toks[0], len(char_to_id))
            to_char_id = char_to_id.setdefault(char_toks[1], len(char_to_id))
            corr_pairs.append(((from_char_id, to_char_id), float(toks[1])))

    num_chars = len(char_to_id)
    if not corr_pairs:
        raise ValueError("No correlation pairs found in {}".format(path))
    # A complete matrix has exactly one value per ordered character pair.
    if len(corr_pairs) != num_chars * num_chars:
        raise ValueError(
            "Expected {} correlation pairs for {} characters, got {}".format(
                num_chars * num_chars, num_chars, len(corr_pairs)))
    logging.info("Read %d correlation pairs, %d characters.", len(corr_pairs),
                 num_chars)

    corr = np.zeros((num_chars, num_chars))
    for (row, col), value in corr_pairs:
        corr[row, col] = value
    # Dicts preserve insertion order, which is exactly the id order here.
    labels = list(char_to_id)
    return pd.DataFrame(corr, columns=labels, index=labels)


def _process_input_file():
    """Reads the correlations from `--input_file` and plots them.

    Raises:
      ValueError: If the input file is malformed or incomplete.
    """
    logging.info("Reading %s ...", FLAGS.input_file)
    corr_df = _read_correlations(FLAGS.input_file)
    corr_df.info()
    _plot(corr_df)
125
126
def main(argv: Sequence[str]) -> None:
    """Validates the command line and runs the plotting pipeline.

    Args:
      argv: Positional command-line arguments left after flag parsing; only
        the program name is allowed.

    Raises:
      app.UsageError: On extra positional arguments or when `--input_file`
        or `--output_file` is not set.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")
    # Checked in order, so a missing --input_file is reported first.
    required = (("input_file", "Specify --input_file"),
                ("output_file", "Specify --output_file"))
    for flag_name, usage_message in required:
        if not getattr(FLAGS, flag_name):
            raise app.UsageError(usage_message)
    _process_input_file()


if __name__ == "__main__":
    app.run(main)
139