google-research

Форк
0
396 строк · 11.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Generic utilities.
17
"""
18
import contextlib
19
import itertools
20
import json
21
import logging
22
import operator
23
import os
24
import pathlib
25
import shutil
26
import tempfile
27
import textwrap
28
import time
29
from typing import Any, Callable, Container, Generator, Iterable, Optional
30
from typing import Type, Union
31

32
from absl import flags
33
import colorama
34
import dataclasses
35
import psutil
36

37
# pylint: enable=g-import-not-at-top
38
EXTERNAL = True
39
# pylint: disable=g-import-not-at-top
40
if EXTERNAL:
41
  from tensorflow.compat.v1 import gfile  # pytype: disable=import-error
42
# pylint: enable=g-import-not-at-top
43

44

45
FLAGS = flags.FLAGS
46
LOGGER = logging.getLogger(__name__)
47
PathType = Union[pathlib.Path, str]
48

49

50
@dataclasses.dataclass
class TimeStamp:
  """Simple dataclass to represent a timestamp.

  Attributes:
    hours: Whole hours (not wrapped at 24).
    minutes: Whole minutes within the hour.
    seconds: Whole seconds within the minute.
    milliseconds: Milliseconds within the second.
  """
  hours: int
  minutes: int
  seconds: int
  milliseconds: int

  def format(self):
    """Returns the timestamp as `HH:MM:SS.mmm`, each field zero-padded."""
    return (f"{str(self.hours).zfill(2)}:{str(self.minutes).zfill(2)}:"
            f"{str(self.seconds).zfill(2)}.{str(self.milliseconds).zfill(3)}")

  @classmethod
  def from_seconds(cls, duration):
    """Builds a timestamp from a duration expressed in seconds.

    Args:
      duration: Duration in seconds (int or float).

    Returns:
      An instance of `cls` holding the duration split into hours, minutes,
      seconds and milliseconds, rounded to the nearest millisecond.
    """
    # Convert to an integer number of milliseconds first. Taking `duration % 1`
    # on a float and truncating loses a millisecond to representation error
    # (e.g. 1.001 % 1 == 0.000999...), so round once and then use exact integer
    # arithmetic for the split.
    total_milliseconds = int(round(duration * 1000))
    total_seconds, milliseconds = divmod(total_milliseconds, 1000)
    total_minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    # `cls` rather than `TimeStamp` so subclasses get instances of themselves.
    return cls(hours, minutes, seconds, milliseconds)
72

73

74
@contextlib.contextmanager
def log_duration(logger, function_name, task_name,
                 level=logging.DEBUG):
  """With statement (context manager) to log the duration of a code block.

  Logs a "Starting task" message on entry and a "Done with task" message with
  the elapsed time on normal exit. If the body of the `with` block raises, the
  exception propagates and the completion message is not logged.

  Arguments:
    logger: The logger to use to do the logging.
    function_name: The function from which log_duration is called.
    task_name: A short description of the task being monitored.
    level: Logging level to use.
  Yields:
    None
  """
  # Do this at the entry inside of the with statement.
  # A single dict argument lets the logging module fill the named
  # %(key)s placeholders lazily.
  logger.log(level, "(%(function_name)s): Starting task  "
                    "`%(color)s%(task_name)s%(reset)s`.",
             dict(function_name=function_name, task_name=task_name,
                  color=colorama.Fore.CYAN, reset=colorama.Style.RESET_ALL))
  # Monotonic clock: unlike time.time(), it cannot jump backwards or forwards
  # when the system clock is adjusted, so the measured duration is never
  # negative or inflated.
  start = time.monotonic()
  yield
  # Do this at the exit of the with statement.
  duration = time.monotonic() - start
  timestamp = TimeStamp.from_seconds(duration)
  logger.log(level,
             "(%(function_name)s): Done with task "
             "`%(cyan)s%(task_name)s%(style_reset)s`. "
             " Took %(color)s`%(ts)s`%(style_reset)s",
             dict(function_name=function_name, task_name=task_name,
                  color=colorama.Fore.GREEN, ts=timestamp.format(),
                  style_reset=colorama.Style.RESET_ALL,
                  cyan=colorama.Fore.CYAN))
105

106

107
def copy_to_tmp(in_file):
  """Copies a file to a tempfile.

  The point of this is to copy small files from CNS to tempdirs on
  the client when using code that hasn't been Google-ified yet.
  Examples of files are the vocab and config files of the Hugging Face
  tokenizer.

  The temporary file is created with `delete=False` and is not removed by
  this function.

  Arguments:
    in_file: Path to the object to be copied, likely in CNS
  Returns:
    Path where the object ended up (inside of the tempdir).
  """
  # We just want to use Python's safe tempfile name generation algorithm.
  # `f_out.name` is already the full path of the created file, so it is used
  # directly instead of being re-joined with the temp directory.
  with tempfile.NamedTemporaryFile(delete=False) as f_out:
    target_path = f_out.name
  # The placeholder file still exists on disk, hence overwrite=True.
  gfile.Copy(in_file, target_path, overwrite=True)
  return target_path
125

126

127
def check_equal(a, b):
  """Asserts that `a == b`, delegating to `check_operator`.

  Args:
    a: Left-hand value of the comparison.
    b: Right-hand value of the comparison.

  Returns:
    Always returns `None`.

  Raises:
    RuntimeError: If `operator.eq(a, b)` is falsy.
  """
  check_operator(op=operator.eq, a=a, b=b)
141

142

143
def check_contained(unit, container):
  """Asserts that `unit in container`, delegating to `check_operator`.

  Args:
    unit: The element expected to be found.
    container: The object searched with `operator.contains`.

  Returns:
    Always returns `None`.

  Raises:
    RuntimeError: If `operator.contains(container, unit)` is falsy.
  """
  # operator.contains takes the container first, then the element.
  check_operator(op=operator.contains, a=container, b=unit)
145

146

147
def check_operator(op, a, b):
  """Asserts that a binary predicate holds for two arguments.

  Args:
    op: Callable taking `(a, b)` whose result is interpreted as a boolean.
    a: Left argument passed to `op`.
    b: Right argument passed to `op`.

  Returns:
    Always returns `None`.

  Raises:
    RuntimeError: If `op(a, b)` is falsy. The message lists the operator and
      both arguments.
  """
  if op(a, b):
    return

  message = ("Operator test failed.\n"
             f"Operator:    {op}\n"
             f"left arg:    {a}\n"
             f"right arg:   {b}")
  raise RuntimeError(message)
166

167

168
def check_isinstance(obj, type_):
  """Asserts that `obj` is an instance of `type_`.

  Args:
    obj: The object to test.
    type_: A type (or tuple of types) accepted by `isinstance`.

  Returns:
    Always returns `None`.

  Raises:
    RuntimeError: If `isinstance(obj, type_)` is false.
  """
  if isinstance(obj, type_):
    return

  message = ("Failed isinstance check.\n"
             f"\tExpected: {type_}\n"
             f"\tGot:      {type(obj)}")
  raise RuntimeError(message)
173

174

175
def check_exists(path):
  """Verifies that something exists at the received path.

  Existence is tested with `gfile.Exists`.

  Arguments:
    path: The path to check.
  Returns:
    Nothing.
  Raises:
    RuntimeError: Raised if `path` is None or nothing exists at that path.
  """
  # Reject None up front so the error names the real problem.
  if path is None:
    raise RuntimeError("Got None instead of a valid path.")

  found = gfile.Exists(path)
  if not found:
    raise RuntimeError(f"File path `{path}` doesn't exist.")
190

191

192
def check_glob_prefix(prefix):
  """Verifies that there is at least one match for a glob prefix.

  The prefix is split into a parent directory and a name; the check passes if
  the parent exists and at least one entry in it matches `<name>*`.

  Args:
    prefix: Glob prefix to check.

  Returns:
    None

  Raises:
    RuntimeError: If `prefix` is None, if the parent path doesn't exist, or if
      there are no matches.
  """
  if prefix is None:
    raise RuntimeError("Got None instead of a valid glob prefix.")

  path = pathlib.Path(prefix)
  if not gfile.Exists(path.parent):
    raise RuntimeError(f"The parent of the glob prefix didn't exist:\n"
                       f" - Glob prefix: {path}\n"
                       f" - Glob parent: {path.parent}")

  # The glob is lazy, so pulling only its first element stays fast no matter
  # how many entries match the prefix.
  matches = path.parent.glob(path.name + "*")
  if next(iter(matches), None) is None:
    # The f-prefix was missing here, so the message used to contain the
    # literal text "{prefix}" instead of the offending value.
    raise RuntimeError(f"No matches to the globbing prefix:\n{prefix}")
219

220

221
def check_not_none(obj):
  """Asserts that `obj` is not `None`.

  Args:
    obj: The object to test. Other falsy values (0, "", []) are accepted.

  Returns:
    Always returns `None`.

  Raises:
    RuntimeError: If `obj` is `None`.
  """
  if obj is not None:
    return
  raise RuntimeError("Object was None.")
224

225

226
def from_json_file(path):
  """Loads and parses a JSON document through `gfile`.

  Args:
    path: Location of the JSON file; converted with `str()` before opening.

  Returns:
    The Python object decoded from the file's contents.
  """
  with gfile.GFile(str(path)) as fin:
    raw_text = fin.read()
  return json.loads(raw_text)
237

238

239
def to_json_file(path, obj, indent=4):
  """Serializes an object as JSON and writes it through `gfile`.

  Args:
    path: Destination of the JSON file; converted with `str()` before opening.
    obj: The JSON-serializable object to write.
    indent: Indentation passed to `json.dumps`.

  Returns:
    None
  """
  with gfile.GFile(str(path), "w") as fout:
    serialized = json.dumps(obj, indent=indent)
    fout.write(serialized)
251

252

253
def log_module_args(
    logger, module_name,
    level=logging.DEBUG, sort=True
):
  """Logs the list of flags defined in a module, as well as their value.

  Args:
    logger: Instance of the logger to use for logging. If None, nothing is
      logged and the formatted text is only returned.
    module_name: Name of the module from which to print the args.
    level: Logging level to use.
    sort: Whether to sort the flags by name.

  Returns:
    The formatted text: a newline followed by a JSON object mapping each flag
    name to its value.

  Raises:
    KeyError: If no flags are registered under `module_name`.
  """
  flags_ = FLAGS.flags_by_module_dict()[module_name]
  if sort:
    # sorted() builds a new list; sorting in place would reorder the list
    # object handed back by FLAGS as a side effect of logging.
    flags_ = sorted(flags_, key=lambda flag: flag.name)
  # `json.dumps` formats dicts in a nice way when indent is specified.
  content = "\n" + json.dumps({flag.name: flag.value for flag in flags_},
                              indent=4)
  if logger is not None:
    logger.log(level, content)
  return content
277

278

279
def term_size(default_cols=80):
  """Returns the terminal width in columns.

  Args:
    default_cols: Column count handed to `shutil.get_terminal_size` as the
      fallback width (paired with a fallback height of 20 lines).

  Returns:
    The `columns` field reported by `shutil.get_terminal_size`.
  """
  fallback = (default_cols, 20)
  terminal = shutil.get_terminal_size(fallback)
  return terminal.columns
281

282

283
def wrap_iterable(
    iterable, numbers=False, length=None
):
  """Formats items as a bulleted list wrapped to a maximum width.

  Each item is converted with `str()` and wrapped with `textwrap.fill`. Items
  are prefixed with " - ", or with " <index> - " (zero-based) when
  `numbers=True`; continuation lines are indented to line up with the text.

  Args:
    iterable: The object with the text instances.
    numbers: Whether to use line numbers instead of plain dashes.
    length: Maximum line width. When None, `term_size(120)` is used, evaluated
      at call time.

  Returns:
    A single string with the wrapped items joined by newlines.
  """
  if length is None:
    # Resolved here rather than in the signature: default values are evaluated
    # once, at function definition time.
    length = term_size(120)

  wrapped_items = []
  for index, item in enumerate(iterable):
    bullet = f" {index} - " if numbers else " - "
    wrapped_items.append(
        textwrap.fill(str(item), length, initial_indent=bullet,
                      subsequent_indent=" " * len(bullet)))
  return "\n".join(wrapped_items)
311

312

313
class MovingAverage:
  """Simple EMA (exponential moving average) of scalar values.

  The first value passed to `update` becomes the average. Every later value is
  blended in as `constant * average + (1 - constant) * value`.
  """

  def __init__(self, constant, settable_average=False):
    """Creates the EMA object.

    Args:
      constant: Update constant, converted to float; must be strictly less
        than 1. It is the weight kept on the previous average at each update.
        For background see
        https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average
      settable_average: Whether assigning to `average` is allowed.

    Raises:
      RuntimeError: If `constant` is not strictly less than 1.
    """
    constant = float(constant)
    check_operator(operator.lt, constant, 1)
    self._constant = constant
    # None until the first call to `update`.
    self._average = None
    self.settable_average = settable_average

  def update(self, value):
    """Folds a new observation (converted to float) into the average."""
    value = float(value)
    if self._average is None:
      # First observation seeds the average.
      self._average = value
      return
    keep = self._constant * self._average
    blend = (1 - self._constant) * value
    self._average = keep + blend

  @property
  def average(self):
    """Current average, or None if `update` was never called."""
    return self._average

  @average.setter
  def average(self, value):
    if not self.settable_average:
      raise RuntimeError("The value of average should not be set this way")
    self._average = float(value)

  def __repr__(self):
    return f"<MovingAverage: self.average={self._average}>"

  def __str__(self):
    return str(self._average)
354

355

356
class FlagChoices:
  """Adds a .choices function with the choices for the Flag.

  Subclasses declare the allowed values as public class attributes;
  `choices()` returns those values as a frozenset. Only attributes defined
  directly on the class it is called on are collected, so values must be
  hashable and attributes inherited from a parent class are not included.

  Example:
    >>> class DirectionChoices(FlagChoices):
    >>>     north = "north"
    >>>     south = "south"
    >>>     east = "east"
    >>>     west = "west"
    >>> # ...
    >>> flags.DEFINE_enum("direction", DirectionChoices.north,
    >>>                    DirectionChoices.choices(), "In which direction do"
    >>>                                                " you want to go.")
    >>> # ...
    >>> # other case
    >>> if argument_value not in DirectionChoices.choices():
    >>>     raise ValueError(f"Value {argument_value} not in DirectionChoices:"
    >>>                      f"{DirectionChoices.choices()}")

  """

  @classmethod
  def choices(cls):
    """Returns the frozenset of values of the class's own public attributes."""
    # Look up the cache in the class's own namespace only. A plain getattr()
    # would also find a `_choices` cached on a parent class (or on FlagChoices
    # itself) and return the wrong set for this class.
    cached = vars(cls).get("_choices")
    if cached is None:
      cached = frozenset(
          value for name, value in vars(cls).items()
          if name != "choices" and not name.startswith("_")
      )
      cls._choices = cached
    return cached
385

386

387
def print_mem(description, logger):
  """Logs the resident memory of the current process at DEBUG level.

  Args:
    description: Free-form text included in the log message.
    logger: Logger whose `debug` method receives the message.
  """
  current_process = psutil.Process(os.getpid())
  log_fields = {
      # RSS is reported in bytes; dividing by 1E9 gives (decimal) gigabytes.
      "mem": current_process.memory_info().rss / 1E9,
      "description": description,
      "yellow": colorama.Fore.YELLOW,
      "reset": colorama.Style.RESET_ALL,
  }
  logger.debug(
      "MEM USAGE:\n"
      " - Usage: %(mem)f GB\n"
      " - Description: %(yellow)s%(description)s%(reset)s",
      log_fields)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.