google-research
82 строки · 2.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Simple utilities."""
17
18import bz2
19import gzip
20import logging
21import re
22
23import numpy as np
24import pandas as pd
25
# Column label of the perplexity values in the tsv file. `read_metrics` parses
# the file without a header row, so columns carry integer labels starting at 0
# and this label selects the third column.
_PERPLEXITY_COLUMN_ID = 2

# Regular expression for matching the ngram order token in a file name: one or
# more digits immediately followed by the literal text "gram".
_ORDER_REGEX = r"\d+gram"
31
32
def ngram_order_from_filename(filename):
  """Extracts the n-gram order encoded in a file name.

  The name has to contain exactly one token matched by `_ORDER_REGEX`
  (digits directly followed by "gram"); the digits are the order.

  Args:
    filename: File name (or path) to inspect.

  Returns:
    The n-gram order as an integer.

  Raises:
    ValueError: If the name contains no order token or more than one.
  """
  tokens = re.findall(_ORDER_REGEX, filename)
  if len(tokens) == 1:
    # Everything ahead of the "gram" suffix is the run of digits.
    digits, _, _ = tokens[0].partition("gram")
    return int(digits)
  raise ValueError(f"Invalid filename {filename}")
40
41
42def open_file(filename, mode="r", encoding="utf-8"):
43"""Open files of several types, with text mode of compressed files.
44
45Args:
46filename: File path to the file which need to be open.
47mode: Open mode, "r" for read and "w" for write.
48encoding: Encoding method for the content of the file.
49
50Returns:
51The opened file handle of the input filename.
52"""
53if filename.endswith(".gz"):
54# The "t" is appended for text mode.
55return gzip.open(filename, mode + "t")
56elif filename.endswith(".bz2"):
57return bz2.open(filename, mode + "t", encoding=encoding)
58else:
59return open(filename, mode, encoding=encoding)
60
61
def ppl_to_entropy(ppl):
  """Converts the perplexity to entropy (bits per character).

  Args:
    ppl: Perplexity, as a scalar or anything numpy can process element-wise
      (for example a numpy array or a pandas Series).

  Returns:
    The base-2 logarithm of `ppl`, element-wise for array-like input.
  """
  # np.log2 evaluates the base-2 logarithm directly instead of dividing two
  # separately rounded base-10 logarithms.
  return np.log2(ppl)
65
66
def read_metrics(file_path):
  """Reads metrics provided in a tsv file into pandas dataframe.

  Args:
    file_path: Path of the tab-separated file. It is parsed without a header
      row, so the resulting columns are labelled by integer position.

  Returns:
    A `pandas.DataFrame` holding the parsed records.
  """
  # Values are passed as logging arguments so the message is only formatted
  # when a handler actually emits the record.
  logging.info("Reading metrics from %s ...", file_path)
  df = pd.read_csv(file_path, sep="\t", header=None)
  logging.info("Read %d samples", df.shape[0])
  return df
73
74
def read_entropies(file_path, as_ppl=False):
  """Loads the perplexity column of a metrics tsv file, optionally as entropy.

  Args:
    file_path: Path of the tsv file, handed to `read_metrics`.
    as_ppl: If true, the perplexity column is returned as read; otherwise it
      is converted with `ppl_to_entropy` first.

  Returns:
    The data frame column labelled `_PERPLEXITY_COLUMN_ID`, either unchanged
    or converted to entropy.
  """
  perplexities = read_metrics(file_path)[_PERPLEXITY_COLUMN_ID]
  return perplexities if as_ppl else ppl_to_entropy(perplexities)
83