# google-research
# 396 lines · 11.1 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Generic utilities.
17"""
18import contextlib
19import itertools
20import json
21import logging
22import operator
23import os
24import pathlib
25import shutil
26import tempfile
27import textwrap
28import time
29from typing import Any, Callable, Container, Generator, Iterable, Optional
30from typing import Type, Union
31
32from absl import flags
33import colorama
34import dataclasses
35import psutil
36
37# pylint: enable=g-import-not-at-top
38EXTERNAL = True
39# pylint: disable=g-import-not-at-top
40if EXTERNAL:
41from tensorflow.compat.v1 import gfile # pytype: disable=import-error
42# pylint: enable=g-import-not-at-top
43
44
# absl's global flag registry; `log_module_args` reads flag values through it.
FLAGS = flags.FLAGS
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Alias for values accepted either as a `pathlib.Path` or as a plain string.
PathType = Union[pathlib.Path, str]
48
49
@dataclasses.dataclass
class TimeStamp:
  """Simple dataclass to represent a duration as a clock-style timestamp.

  Attributes:
    hours: Whole hours. Not wrapped into days, so it can exceed 23.
    minutes: Whole minutes past the hour.
    seconds: Whole seconds past the minute.
    milliseconds: Milliseconds past the second.
  """
  hours: int
  minutes: int
  seconds: int
  milliseconds: int

  def format(self):
    """Returns the timestamp as a zero-padded `HH:MM:SS.mmm` string."""
    return (f"{str(self.hours).zfill(2)}:{str(self.minutes).zfill(2)}:"
            f"{str(self.seconds).zfill(2)}.{str(self.milliseconds).zfill(3)}")

  @classmethod
  def from_seconds(cls, duration):
    """Builds a timestamp from a duration expressed in seconds.

    Args:
      duration: Duration in seconds; may be fractional.

    Returns:
      An instance of `cls` holding the duration rounded to the nearest
      millisecond.
    """
    # Convert to an integer number of milliseconds once, up front. Repeated
    # float modulo operations accumulate representation error (1.001 s used to
    # come out with 0 ms instead of 1 ms).
    total_milliseconds = int(round(float(duration) * 1000))
    total_seconds, milliseconds = divmod(total_milliseconds, 1000)
    total_minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    # Use `cls` rather than the hard-coded class so subclasses get instances
    # of their own type.
    return cls(hours, minutes, seconds, milliseconds)
72
73
@contextlib.contextmanager
def log_duration(logger, function_name, task_name,
                 level=logging.DEBUG):
  """With statement (context manager) to log the duration of a code block.

  One message is logged when the block is entered and one when it exits
  normally. If the block raises, the exception propagates and the completion
  message is not logged.

  Arguments:
    logger: The logger to use to do the logging.
    function_name: The function from which log_duration is called.
    task_name: A short description of the task being monitored.
    level: Logging level to use.
  Yields:
    None
  """
  # Do this at the entry inside of the with statement
  logger.log(level, "(%(function_name)s): Starting task "
             "`%(color)s%(task_name)s%(reset)s`.",
             dict(function_name=function_name, task_name=task_name,
                  color=colorama.Fore.CYAN, reset=colorama.Style.RESET_ALL))
  # perf_counter is monotonic, so the measured duration cannot be skewed by
  # system clock adjustments the way a wall-clock difference can.
  start = time.perf_counter()
  yield
  # Do this at the exit of the with statement
  duration = time.perf_counter() - start
  timestamp = TimeStamp.from_seconds(duration)
  logger.log(level,
             "(%(function_name)s): Done with task "
             "`%(cyan)s%(task_name)s%(style_reset)s`. "
             " Took %(color)s`%(ts)s`%(style_reset)s",
             dict(function_name=function_name, task_name=task_name,
                  color=colorama.Fore.GREEN, ts=timestamp.format(),
                  style_reset=colorama.Style.RESET_ALL,
                  cyan=colorama.Fore.CYAN))
105
106
def copy_to_tmp(in_file):
  """Copies a file to a tempfile.

  The point of this is to copy small files from CNS to tempdirs on
  the client when using code that hasn't been Google-ified yet.
  Examples of files are the vocab and config files of the Hugging Face
  tokenizer.

  The temporary file is created with `delete=False`, so it is not removed
  automatically; cleaning it up is left to the caller.

  Arguments:
    in_file: Path to the object to be copied, likely in CNS
  Returns:
    Path where the object ended up (inside of the tempdir).
  """
  # We just want to use Python's safe tempfile name generation algorithm.
  # `name` is already the full path of the reserved file, so no join with the
  # temp directory is needed.
  with tempfile.NamedTemporaryFile(delete=False) as f_out:
    target_path = f_out.name
  # Copy only after our own handle is closed, so the destination is not held
  # open while it is being overwritten.
  gfile.Copy(in_file, target_path, overwrite=True)
  return target_path
125
126
def check_equal(a, b):
  """Raises unless `a == b`.

  Thin wrapper that delegates to `check_operator` with `operator.eq`.

  Args:
    a: Left-hand value of the comparison.
    b: Right-hand value of the comparison.

  Returns:
    None.

  Raises:
    RuntimeError: If `a == b` is falsy.
  """
  equality = operator.eq
  check_operator(equality, a, b)
141
142
def check_contained(unit, container):
  """Raises unless `unit in container`.

  Delegates to `check_operator` with `operator.contains`; note that the
  container is passed as the left argument and the unit as the right one.

  Args:
    unit: The element expected to be present.
    container: The collection that should contain `unit`.

  Returns:
    None.

  Raises:
    RuntimeError: If `unit` is not in `container`.
  """
  membership = operator.contains
  check_operator(membership, container, unit)
145
146
def check_operator(op, a, b):
  """Applies a two-argument predicate and raises if the result is falsy.

  Args:
    op: Callable invoked as `op(a, b)`, e.g. a function from `operator`.
    a: Left argument passed to `op`.
    b: Right argument passed to `op`.

  Returns:
    None.

  Raises:
    RuntimeError: If `op(a, b)` is falsy. The message lists the operator and
      both arguments.
  """
  if op(a, b):
    return

  message = ("Operator test failed.\n"
             f"Operator: {op}\n"
             f"left arg: {a}\n"
             f"right arg: {b}")
  raise RuntimeError(message)
166
167
def check_isinstance(obj, type_):
  """Raises unless `obj` is an instance of `type_`.

  Args:
    obj: The object to inspect.
    type_: A type or tuple of types, as accepted by `isinstance`.

  Returns:
    None.

  Raises:
    RuntimeError: If the `isinstance` check fails. The message shows the
      expected type(s) and the actual type of `obj`.
  """
  if isinstance(obj, type_):
    return

  actual_type = type(obj)
  raise RuntimeError("Failed isinstance check.\n"
                     f"\tExpected: {type_}\n"
                     f"\tGot: {actual_type}")
173
174
def check_exists(path):
  """Verifies that something exists at the received path.

  Existence is checked with `gfile.Exists`.

  Arguments:
    path: The path to check.
  Returns:
    Nothing.
  Raises:
    RuntimeError: Raised if `path` is None or if nothing exists at the
      received path.
  """
  if path is None:
    raise RuntimeError("Got None instead of a valid path.")

  found = gfile.Exists(path)
  if not found:
    raise RuntimeError(f"File path `{path}` doesn't exist.")
190
191
def check_glob_prefix(prefix):
  """Verifies that there is at least one match for a glob prefix.

  The parent directory of the prefix is checked with `gfile.Exists`; the
  matches themselves are searched for with `pathlib.Path.glob` on
  `<prefix name>*` inside that parent.

  Args:
    prefix: Glob prefix to check.

  Returns:
    None

  Raises:
    RuntimeError: If `prefix` is None, if the parent path doesn't exist, or if
      there are no matches.
  """
  if prefix is None:
    raise RuntimeError("Got None instead of a valid glob prefix.")

  path = pathlib.Path(prefix)
  if not gfile.Exists(path.parent):
    raise RuntimeError(f"The parent of the glob prefix didn't exist:\n"
                       f" - Glob prefix: {path}\n"
                       f" - Glob parent: {path.parent}")
  matches = path.parent.glob(path.name + "*")
  # The glob is lazy: asking for a single item stays cheap no matter how many
  # files match, and avoids materializing the full list just to test for
  # emptiness.
  if next(matches, None) is None:
    # This message was previously missing its `f` prefix and printed the
    # literal text `{prefix}` instead of the offending value.
    raise RuntimeError(f"No matches to the globbing prefix:\n{prefix}")
219
220
def check_not_none(obj):
  """Raises if the received object is `None`.

  Only identity with `None` is tested; other falsy values (0, "", empty
  containers) pass.

  Args:
    obj: The object to check.

  Returns:
    None.

  Raises:
    RuntimeError: If `obj` is None.
  """
  if obj is not None:
    return
  raise RuntimeError("Object was None.")
224
225
def from_json_file(path):
  """Loads and parses a json file.

  The file is opened through `gfile.GFile`, and `path` is converted with
  `str` first so `pathlib.Path` objects are accepted as well.

  Args:
    path: Path to read from.

  Returns:
    The object decoded from the json content of the file.
  """
  with gfile.GFile(str(path)) as fin:
    serialized = fin.read()
    return json.loads(serialized)
237
238
def to_json_file(path, obj, indent=4):
  """Serializes an object to json and writes it to a file.

  The file is opened for writing through `gfile.GFile`, and `path` is
  converted with `str` first so `pathlib.Path` objects are accepted as well.

  Args:
    path: Where to save.
    obj: The object to save; must be serializable by `json.dumps`.
    indent: Indentation passed to `json.dumps`.

  Returns:
    None
  """
  with gfile.GFile(str(path), "w") as fout:
    serialized = json.dumps(obj, indent=indent)
    fout.write(serialized)
251
252
def log_module_args(
    logger, module_name,
    level=logging.DEBUG, sort=True
):
  """Logs the list of flags defined in a module, as well as their value.

  Args:
    logger: Instance of the logger to use for logging. If None, nothing is
      logged and the formatted text is only returned.
    module_name: Name of the module from which to print the args. It is used
      as a key into `FLAGS.flags_by_module_dict()`.
    level: Logging level to use.
    sort: Whether to sort the flags by name.

  Returns:
    The formatted text: a newline followed by a json object mapping each flag
    name to its value.
  """
  flags_ = FLAGS.flags_by_module_dict()[module_name]
  if sort:
    # Build a sorted copy instead of sorting in place: the list comes from the
    # flag registry and may be the registry's own list, which a formatting
    # helper should not reorder.
    flags_ = sorted(flags_, key=lambda flag: flag.name)
  # `json.dumps` formats dicts in a nice way when indent is specified.
  content = "\n" + json.dumps(
      {flag.name: flag.value for flag in flags_}, indent=4)
  if logger is not None:
    logger.log(level, content)
  return content
277
278
def term_size(default_cols=80):
  """Returns the width of the terminal, in columns.

  Args:
    default_cols: Column count handed to `shutil.get_terminal_size` as the
      fallback, paired with a fallback height of 20 lines.

  Returns:
    The number of columns reported by `shutil.get_terminal_size`.
  """
  fallback = (default_cols, 20)
  columns, _ = shutil.get_terminal_size(fallback)
  return columns
281
282
def wrap_iterable(
    iterable, numbers=False, length=None
):
  """Renders each item as a bullet entry wrapped to a maximum width.

  Every item is converted with `str` and wrapped with `textwrap.fill`. Entries
  start with a dash by default, or with their zero-based index followed by a
  dash if numbers=True. Continuation lines of an entry are indented.

  Args:
    iterable: The object with the text instances.
    numbers: Whether to prefix entries with their index instead of a plain
      dash.
    length: Maximum line width. If None, the current terminal width is used,
      as reported by `term_size(120)`.

  Returns:
    A single string with all the wrapped entries joined by newlines.
  """
  # The width has to be resolved at call time: a default argument would be
  # evaluated only once, when the function is defined.
  width = term_size(120) if length is None else length

  entries = []
  for index, item in enumerate(iterable):
    if numbers:
      bullet = f" {index} - "
      hanging = " " * len(bullet)
    else:
      bullet = " - "
      hanging = " "
    entries.append(textwrap.fill(str(item), width, initial_indent=bullet,
                                 subsequent_indent=hanging))
  return "\n".join(entries)
311
312
class MovingAverage:
  """Simple EMA (exponential moving average) of a stream of scalar values."""

  def __init__(self, constant, settable_average=False):
    """Creates the EMA object.

    Args:
      constant: Update constant; converted to float and required to be
        strictly smaller than 1. It is the weight kept on the previous average
        at each update, the new value receiving `1 - constant`. See
        https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average
      settable_average: Whether assigning to `average` directly is allowed.

    Raises:
      RuntimeError: If `constant` is not smaller than 1.
    """
    decay = float(constant)
    check_operator(operator.lt, decay, 1)
    self._constant = decay
    self._average = None
    self.settable_average = settable_average

  def update(self, value):
    """Folds a new value into the running average.

    The very first value seeds the average as is.

    Args:
      value: The new observation; converted to float.
    """
    sample = float(value)
    if self._average is None:
      self._average = sample
      return
    self._average = (self._constant * self._average
                     + (1 - self._constant) * sample)

  @property
  def average(self):
    """Current average, or None if `update` was never called."""
    return self._average

  @average.setter
  def average(self, value):
    # Direct assignment is opt-in, through the `settable_average` attribute.
    if not self.settable_average:
      raise RuntimeError("The value of average should not be set this way")
    self._average = float(value)

  def __repr__(self):
    return f"<MovingAverage: self.average={self._average}>"

  def __str__(self):
    return str(self._average)
354
355
class FlagChoices:
  """Adds a .choices function with the choices for the Flag.

  The choices are the values of the public attributes (names not starting
  with an underscore) defined directly in the body of the class on which
  `choices` is called.

  Example:
    >>> class DirectionChoices(FlagChoices):
    >>>   north = "north"
    >>>   south = "south"
    >>>   east = "east"
    >>>   west = "west"
    >>> # ...
    >>> flags.DEFINE_enum("direction", DirectionChoices.north,
    >>>                   DirectionChoices.choices(), "In which direction do"
    >>>                   " you want to go.")
    >>> # ...
    >>> # other case
    >>> if argument_value not in DirectionChoices.choices():
    >>>   raise ValueError(f"Value {argument_value} not in DirectionChoices:"
    >>>                    f"{DirectionChoices.choices()}")

  """

  @classmethod
  def choices(cls):
    """Returns the frozenset of choice values defined on this class.

    The result is computed on first use and cached on the class itself.
    """
    # Look the cache up in the class's own namespace only. An inherited lookup
    # would find the cache of a parent class (or of `FlagChoices` itself) and
    # hand a subclass the wrong set of choices.
    if vars(cls).get("_choices") is None:
      cls._choices = frozenset(
          value for name, value in vars(cls).items()
          if name != "choices" and not name.startswith("_")
      )
    return cls._choices
385
386
def print_mem(description, logger):
  """Logs the current memory use of the current process at debug level.

  The figure reported is the resident set size given by `psutil`, divided by
  1E9 to express it in GB.

  Args:
    description: Free-form text included in the message.
    logger: The logger whose `debug` method receives the message.
  """
  current_process = psutil.Process(os.getpid())
  template = ("MEM USAGE:\n"
              " - Usage: %(mem)f GB\n"
              " - Description: %(yellow)s%(description)s%(reset)s")
  fields = {
      "mem": current_process.memory_info().rss / 1E9,
      "description": description,
      "yellow": colorama.Fore.YELLOW,
      "reset": colorama.Style.RESET_ALL,
  }
  logger.debug(template, fields)