google-research

Форк
0
/
describe_splits.py 
72 строки · 2.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
r"""Collects basic stats for training and test splits from the results file.
17

18
Example:
19
--------
20
  LANGUAGE=...
21
  cat data/ngrams/results/reading/00/baselines/${LANGUAGE}.*.tsv > /tmp/${LANGUAGE}.tsv
22
  python describe_splits.py \
23
    --results_tsv_file /tmp/${LANGUAGE}.tsv
24

25
Dependencies:
26
-------------
27
  absl
28
  pandas
29
  statsmodels
30
"""
31

32
from typing import Sequence
33

34
import logging
35

36
from absl import app
37
from absl import flags
38

39
import pandas as pd
40
import statsmodels.stats.api as sms
41

42
# Input file flag; `main` rejects the empty default with a usage error.
flags.DEFINE_string(
    name="results_tsv_file",
    default="",
    help="Results text file in tab-separated (tsv) format.")

FLAGS = flags.FLAGS
47

48

49
def _to_str(stats):
  """Formats the mean, variance and standard deviation of `stats` as text.

  Args:
    stats: Object exposing `mean`, `var` and `std` attributes (`main` passes
      in `sms.DescrStatsW` instances).

  Returns:
    A single-line string of the form "mean: <mean> var: <var> std: <std>".
  """
  mean, variance, std_dev = stats.mean, stats.var, stats.std
  return f"mean: {mean} var: {variance} std: {std_dev}"
52

53

54
def main(argv):
  """Logs descriptive statistics for the train and test split sizes.

  Reads the headerless tab-separated file given by --results_tsv_file, then
  logs the mean, variance and standard deviation of column 0 (training token
  counts) and column 1 (test token counts).

  Args:
    argv: Command-line arguments; at most one entry is accepted.

  Raises:
    app.UsageError: If extra command-line arguments are present or if
      --results_tsv_file is empty.
  """
  # Any entry past the first one is an unexpected extra argument.
  if argv[1:]:
    raise app.UsageError("Too many command-line arguments.")
  results_path = FLAGS.results_tsv_file
  if not results_path:
    raise app.UsageError("Specify --results_tsv_file [FILE]!")

  logging.info(f"Reading metrics from {results_path} ...")
  frame = pd.read_csv(results_path, sep="\t", header=None)
  logging.info(f"Read {frame.shape[0]} samples")

  # Column 0: training token counts (a token can be a char or a word).
  train_summary = sms.DescrStatsW(list(frame[0]))
  logging.info(f"Train stats: {_to_str(train_summary)}")

  # Column 1: test token counts.
  test_summary = sms.DescrStatsW(list(frame[1]))
  logging.info(f"Test stats: {_to_str(test_summary)}")
69

70

71
if __name__ == "__main__":
  # Hand control to absl, which invokes main() with the command-line arguments.
  app.run(main)
73

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.