google-research

Форк
0
716 строк · 26.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""A generic serializable hyperparameter container.

hparam is based on `attr` under the covers, but provides additional features,
such as serialization and deserialization in a format that is compatible with
(now defunct) tensorflow.HParams, runtime type checking, implicit casting where
safe to do so (e.g. int->float, scalar->list).

Unlike tensorflow.HParams, this supports hierarchical nesting of parameters for
better organization, aliasing parameters to short abbreviations for compact
serialization while maintaining code readability, and support for Enum values.

Example usage:
  @hparam.s
  class MyNestedHParams:
    learning_rate: float = hparam.field(abbrev='lr', default=0.1)
    layer_sizes: List[int] = hparam.field(abbrev='ls', default=[256, 64, 32])

  @hparam.s
  class MyHParams:
    nested_params: MyNestedHParams = hparam.nest(MyNestedHParams)
    non_nested_param: int = hparam.field(abbrev='nn', default=0)

  hparams = MyHParams(nested_params=MyNestedHParams(
                        learning_rate=0.02, layer_sizes=[100, 10]),
                      non_nested_param=5)
  hparams.nested_params.learning_rate = 0.002
  serialized = hparams.serialize()  # "lr=0.002,ls=[100,10],nn=5"
  hparams.nested_params.learning_rate = 0.003
  new_hparams = MyHParams(serialized)
  new_hparams.nested_params.learning_rate == 0.002  # True
"""
47

48
import collections
49
import copy
50
import csv
51
import enum
52
import inspect
53
import json
54
import numbers
55
import re
56
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
57

58
import attr
59
import six
60

61
# Keys under which hparam-specific information is stored in the `metadata`
# mapping of each attr.ib() created by `field` / `nest`.
_ABBREV_KEY = 'hparam.abbrev'
_SCALAR_TYPE_KEY = 'hparam.scalar_type'
_IS_LIST_KEY = 'hparam.is_list'
_PREFIX_KEY = 'hparam.prefix'
# Name of the extra positional attribute that @hparam.s adds to every class so
# that a serialized string can be passed directly to the constructor.
_SERIALIZED_ARG = '_hparam_serialized_arg'

# Define the regular expression for parsing a single clause of the input
# (delimited by commas).  A legal clause looks like:
#   <variable name> = <rhs>
# where <rhs> is a double-quoted string, a single token, or a []-enclosed list
# of tokens.
# For example:  'var = a', 'x = [1,2,3]' or 's = "a,b=c"'.
#
# The quoted-string alternative is non-greedy: it ends at the first closing
# quote that is followed by a comma or the end of the input. A greedy `".*"`
# would run to the LAST suitable quote in the whole input, so
# 'a="x,y",b="z"' would be parsed as a single value for `a` and `b` would be
# silently lost.
_PARAM_RE = re.compile(
    r"""
  (?P<name>[a-zA-Z][\w\.]*)      # variable name: 'var' or 'x'
  \s*=\s*
  ((?P<strval>".*?")             # double-quoted string value: '"a,b=c"' or None
   |
   (?P<val>[^,\[]*)              # single value: 'a' or None
   |
   \[(?P<vals>[^\]]*)\])         # list of values: None or '1,2,3'
                                 # (the regex removes the surrounding brackets)
  ($|,\s*)""", re.VERBOSE)

# Type aliases describing the values a field may hold.
_ValidScalarInstanceType = Union[int, float, str, enum.Enum]
_ValidListInstanceType = Union[List[_ValidScalarInstanceType],
                               Tuple[_ValidScalarInstanceType]]
_ValidFieldInstanceType = Union[_ValidScalarInstanceType,
                                _ValidListInstanceType]
_ValidScalarType = Type[_ValidScalarInstanceType]
T = TypeVar('T')
91

92

93
@attr.s(auto_attribs=True)
class _FieldInfo:
  """Metadata describing one leaf field of an @hparam.s class hierarchy."""

  # Attribute names leading from the root hparam.s instance to this field.
  # Every element except the last names a field created with hparam.nest()
  # (whose value is another hparam.s instance); the last element names the
  # hparam.field() itself. E.g. ['foo', 'bar'] addresses root.foo.bar.
  path: List[str]
  # For a scalar field, the type of its value. For a list field, the type that
  # every element shares.
  scalar_type: _ValidScalarType
  # True when the field holds a list of `scalar_type` values.
  is_list: bool
  # The field's default value.
  default_value: _ValidFieldInstanceType


# Maps a serialization abbreviation to the metadata of the field it names.
_HParamsMapType = Dict[str, _FieldInfo]
114

115

116
def _get_type(instance):
  """Determines the scalar type of a value and whether it is a list.

  For lists and tuples, the scalar type is the type of the first element, and
  every element must be an instance of that type.

  Valid scalar types are numbers (including bool), strings, bytes, and Enums
  whose values are numbers or strings.

  Args:
    instance: The value whose type is to be determined (e.g. the default value
      passed to `field`).

  Returns:
    A `(scalar_type, is_list)` tuple:
      scalar_type: The type of `instance`'s first element if `instance` is a
        list or tuple, otherwise the type of `instance`.
      is_list: Whether `instance` is a non-string iterable.

  Raises:
    TypeError: If `instance` is not a valid type:
      * `instance` is an empty list or tuple (or any empty non-string
        iterable).
      * `instance` is an iterable that is not a sequence (e.g. dict or set).
      * `instance` is a list or tuple with elements of different types.
      * `instance` is None or a list or tuple of Nones.
      * `instance` (or its elements) is not a number, string, bytes or Enum.
      * `instance` is an Enum whose values are neither numbers nor strings.
    ValueError: If `instance` is a list or tuple whose elements are themselves
      non-string iterables (nested lists, dicts, ...).
  """
  # six.string_types / six.binary_type on Python 3.
  string_types = (str, bytes)

  is_list = False
  scalar_type = type(instance)
  if isinstance(instance, Iterable) and not issubclass(scalar_type,
                                                       string_types):
    is_list = True
    if not instance:
      raise TypeError('Empty iterables cannot be used as default values.')

    if not isinstance(instance, collections.abc.Sequence):
      # Most likely a Dictionary or Set.
      raise TypeError('Only numbers, strings, and lists are supported. Found '
                      f'{scalar_type}.')

    first = instance[0]
    scalar_type = type(first)
    if isinstance(first, Iterable) and not issubclass(scalar_type,
                                                      string_types):
      raise ValueError('Nested iterables and dictionaries are not supported.')
    if not all(isinstance(element, scalar_type) for element in instance):
      raise TypeError('Iterables of mixed type are not supported.')

  if issubclass(scalar_type, type(None)):
    raise TypeError('Fields cannot have a default value of None.')

  valid_field_types = string_types + (numbers.Integral, numbers.Number)
  if not issubclass(scalar_type, valid_field_types + (enum.Enum,)):
    raise TypeError(
        'Supported types include: number, string, Enum, and lists of those '
        f'types. {scalar_type} is not one of those.')

  if issubclass(scalar_type, enum.Enum):
    # Only the first member is inspected; its value type is assumed to be
    # shared by all members.
    first_member = next(iter(scalar_type.__members__.values()))
    enum_value_type = type(first_member.value)
    if not issubclass(enum_value_type, valid_field_types):
      raise TypeError(f'Enum type {scalar_type} has values of type '
                      f'{enum_value_type}, which is not allowed. Enum values '
                      'must be numbers or strings.')

  return (scalar_type, is_list)
185

186

187
def _make_converter(scalar_type, is_list):
  """Produces a function that casts a value to the target type, if compatible.

  Args:
    scalar_type: The scalar type of the hparam.
    is_list: Whether the hparam is a list type.

  Returns:
    A function that casts its input to either `scalar_type` or
    `List[scalar_type]` depending on `is_list`. For list fields, a scalar input
    is wrapped in a one-element list. Any non-string iterable input (e.g. a
    tuple) is converted element-wise into a new list.
  """
  # six.string_types / six.binary_type on Python 3.
  string_types = (str, bytes)

  def is_listlike(value):
    """Whether `value` is a non-string iterable."""
    return isinstance(value, Iterable) and not isinstance(value, string_types)

  def scalar_converter(value):
    """Converts a scalar value to the target type, if compatible.

    Args:
      value: The value to be converted.

    Returns:
      The converted value.

    Raises:
      TypeError: If the type of `value` is not compatible with scalar_type.
        * If `value` is a non-string iterable.
        * If `scalar_type` is str but `value` is not a str, or `scalar_type`
          is bytes but `value` is not bytes.
        * If `scalar_type` is a boolean, but `value` is not, or vice versa.
        * If `scalar_type` is an integer type, but `value` is not.
        * If `scalar_type` is a float type, but `value` is not a numeric type.
    """
    if is_listlike(value):
      raise TypeError('Nested iterables are not supported')

    # If `value` is already of type `scalar_type`, return it directly.
    # `isinstance` is too weak (e.g. isinstance(True, int) == True).
    if type(value) == scalar_type:  # pylint: disable=unidiomatic-typecheck
      return value

    # Some callers use None, for which we can't do any casting/checking. :(
    if issubclass(scalar_type, type(None)):
      return value

    # Avoid converting a non-string type to a string. str and bytes are not
    # interchangeable either: str(b'x') would silently produce the repr "b'x'".
    if issubclass(scalar_type, string_types):
      expected_type = str if issubclass(scalar_type, str) else bytes
      if not isinstance(value, expected_type):
        raise TypeError(
            f'Expected a string value but found {value} with type {type(value)}.')

    # Avoid converting a number or string type to a boolean or vice versa.
    if issubclass(scalar_type, bool) != isinstance(value, bool):
      raise TypeError(
          f'Expected a bool value but found {value} with type {type(value)}.')

    # Avoid converting float to an integer (the reverse is fine).
    if (issubclass(scalar_type, numbers.Integral) and
        not isinstance(value, numbers.Integral)):
      raise TypeError(
          f'Expected an integer value, but found {value} with type '
          f'{type(value)}.'
      )

    # Avoid converting a non-numeric type to a numeric type.
    if (issubclass(scalar_type, numbers.Number) and
        not isinstance(value, numbers.Number)):
      raise TypeError(
          f'Expected a numeric type, but found {value} with type {type(value)}.'
      )

    return scalar_type(value)

  def converter(value):
    """Converts a value to the target type, if compatible.

    Args:
      value: The value to be converted.

    Returns:
      The converted value.

    Raises:
      TypeError: If the type of `value` is not compatible with `scalar_type` and
      `is_list`.
        * Any of the conditions listed for a single scalar value.
        * If `is_list` is False, but value is a non-string iterable.
    """
    if is_listlike(value):
      if not is_list:
        raise TypeError('Assigning an iterable to a scalar field.')
      return [scalar_converter(v) for v in value]
    if is_list:
      return [scalar_converter(value)]
    return scalar_converter(value)

  return converter
290

291

292
def field(abbrev, default):
  """Declares a single hyperparameter on a class decorated with @hparam.s.

  Each field needs a serialization abbreviation, conventionally a short string,
  which keeps the serialized form concise. Each field also needs a default
  value. The default fixes the field's type for good: later assignments are
  converted to that type, or rejected if they are incompatible. Valid types
  are integers, floats, strings, Enums whose values are of those types, and
  lists of any of these.

  Child HParams classes must be declared with `nest`, not `field`.

  Example usage:
    @hparam.s
    class MyHparams:
      learning_rate: float = hparam.field(abbrev='lr', default=0.1)
      layer_sizes: List[int] = hparam.field(abbrev='ls', default=[256, 64, 32])
      optimizer: OptimizerEnum = hparam.field(
          abbrev='opt', default=OptimizerEnum.SGD)

  Args:
    abbrev: A short string naming this hyperparameter in the serialized format.
    default: The required default value. It must be one of the valid types
      listed above. None is not allowed, and list defaults must be non-empty.
      List defaults are shallow-copied into every instance. Later changes to
      the list object passed here therefore never reach existing or future
      instances.

  Returns:
    A field descriptor to be consumed by a class decorated with @hparam.s.

  Raises:
    TypeError if the default value is not one of the allowed types.
  """
  scalar_type, is_list = _get_type(default)
  metadata = {
      _ABBREV_KEY: abbrev,
      _SCALAR_TYPE_KEY: scalar_type,
      _IS_LIST_KEY: is_list,
  }
  converter = _make_converter(scalar_type, is_list)

  if not is_list:
    # Scalars are immutable, so the default can be shared directly.
    return attr.ib(
        kw_only=True, metadata=metadata, converter=converter, default=default)

  def copy_default():
    # A fresh shallow copy per instance keeps instances from mutating each
    # other's (or the caller's) list.
    return copy.copy(default)

  return attr.ib(
      kw_only=True,
      metadata=metadata,
      converter=converter,
      factory=copy_default)
348

349

350
def nest(nested_class, prefix=None):
  """Declares a field whose value is a nested @hparam.s instance.

  Nesting one HParams class (a class decorated with @hparam.s) inside another
  builds a hierarchical structure. Fields of that kind must be created with
  `nest`.

  Example usage:
    @hparam.s
    class MyNestedHParams:
      learning_rate: float = hparam.field(abbrev='lr', default=0.1)
      layer_sizes: List[int] = hparam.field(abbrev='ls', default=[256, 64, 32])

    @hparam.s
    class MyHParams:
      nested_params: MyNestedHParams = hparam.nest(MyNestedHParams)
      non_nested_param: int = hparam.field(abbrev='nn', default=0)

  Args:
    nested_class: The class (not an instance) of the nested hyperparameters.
      It must be decorated with @hparam.s.
    prefix: Optional string that is prepended to the abbreviation of every
      field of the nested class. Distinct prefixes allow the same class to be
      nested more than once without abbreviation collisions.

  Returns:
    A field descriptor to be consumed by a class decorated with @hparam.s. Its
    default is produced by calling `nested_class()`.

  Raises:
    TypeError if `nested_class` is not a class or is not decorated with
    @hparam.s.
  """
  if not inspect.isclass(nested_class):
    raise TypeError('nest() must be passed a class, not an instance.')

  is_hparams_class = attr.has(nested_class) and getattr(
      nested_class, '__hparams_class__', False)
  if not is_hparams_class:
    raise TypeError('Nested hparams classes must use the @hparam.s decorator')

  nested_metadata = {_PREFIX_KEY: prefix}
  return attr.ib(kw_only=True, factory=nested_class, metadata=nested_metadata)
389

390

391
def _serialize_value(value, field_info):
  """Serializes a field value to the compact string format.

  List elements are serialized one by one, joined with commas (no spaces) and
  enclosed in brackets. Enum members are serialized as their value, and bools
  as 0/1. Strings are wrapped in double quotes when they contain any of
  `,=[]"` or have leading/trailing whitespace. Quoting is needed for
  whitespace because the parser drops whitespace that directly follows `=` in
  an unquoted value. Everything else is cast using str().

  Args:
    value: The value to be serialized.
    field_info: The field info corresponding to `value`. Only its
      `scalar_type` and `is_list` attributes are read.

  Returns:
    The serialized value.
  """
  if field_info.is_list:
    # Manually string-ify the list, since default str(list) adds whitespace.
    items = ','.join(
        _serialize_scalar(element, field_info.scalar_type) for element in value)
    return f'[{items}]'
  return _serialize_scalar(value, field_info.scalar_type)


def _serialize_scalar(value, scalar_type):
  """Serializes a single (non-list) value whose field type is `scalar_type`."""
  if issubclass(scalar_type, enum.Enum):
    return str(value.value)
  if scalar_type == bool:
    # Use 0/1 instead of True/False for more compact serialization.
    return str(int(value))
  if issubclass(scalar_type, str):
    needs_quotes = (
        any(char in value for char in ',=[]"') or value != value.strip())
    if needs_quotes:
      return f'"{value}"'
  return str(value)
428

429

430
def _parse_bool(value, name):
  """Parses `value` as a boolean for the hyperparameter `name`.

  Accepts 'true', 'True', 'false', 'False', or anything `int()` accepts, in
  which case any non-zero integer is True.

  Raises:
    ValueError: If `value` is none of the above.
  """
  if value in ('true', 'True'):
    return True
  if value in ('false', 'False'):
    return False
  try:
    return bool(int(value))
  except ValueError as err:
    raise ValueError(
        f'Could not parse {value} as a boolean for hyperparameter '
        f'{name}.') from err


def _make_value_parser(scalar_type, name):
  """Returns a function parsing one serialized token into `scalar_type`."""
  if scalar_type == bool:
    return lambda token: _parse_bool(token, name)
  if issubclass(scalar_type, enum.Enum):
    enum_type = scalar_type  # type: Type[enum.Enum]
    # Enum members are serialized by value, so parse the token as the type of
    # the enum's values (taken from its first member), then look the member up
    # by that value.
    enum_value_type = type(list(enum_type.__members__.values())[0].value)
    if enum_value_type == bool:
      return lambda token: enum_type(_parse_bool(token, name))
    return lambda token: enum_type(enum_value_type(token))
  return scalar_type


def _parse_serialized(values, hparams_map):
  """Parses hyperparameter values from a string into a python map.

  `values` is a string containing comma-separated `name=value` pairs.
  For each pair, the value of the hyperparameter named `name` is set to
  `value`.

  If a hyperparameter name appears multiple times in `values`, a ValueError
  is raised (e.g. 'a=1,a=2').

  The `value` in `name=value` must follow the syntax according to the
  type of the parameter:

  *  Scalar integer: A Python-parsable integer value.  E.g.: 1, 100, -12.
  *  Scalar float: A Python-parsable floating point value.  E.g.: 1.0,
     -.54e89.
  *  Boolean: True, False, true, false, or an integer (non-zero is True).
  *  Enum: the member's value, parsed as the type of the enum's values.
  *  Scalar string: A sequence of characters, possibly surrounded by
     double-quotes.  E.g.: foo, bar_1, "foo,bar".
  *  List: A comma separated list of scalar values of the parameter type
     enclosed in square brackets.  E.g.: [1,2,3], [1.0,1e-12], [high,low].
     `[]` denotes an empty list, and a single unbracketed value assigned to a
     list parameter becomes a one-element list.

  Args:
    values: Comma separated list of `name=value` pairs where 'value' must follow
      the syntax described above.
    hparams_map: A mapping from abbreviation to field info, detailing the
      expected type information for each known field.

  Returns:
    A python map mapping each name to either:
    * A scalar value.
    * A list of scalar values.

  Raises:
    ValueError: If there is a problem with input.
    * If `values` cannot be parsed.
    * If the same hyperparameter is assigned to twice.
    * If an unknown hyperparameter is assigned to.
    * If a list is assigned to a scalar hyperparameter.
  """
  results_dictionary = {}
  pos = 0
  while pos < len(values):
    m = _PARAM_RE.match(values, pos)
    if not m:
      raise ValueError(f'Malformed hyperparameter value: {values[pos:]}')
    pos = m.end()
    m_dict = m.groupdict()
    name = m_dict['name']
    if name not in hparams_map:
      raise ValueError(f'Unknown hyperparameter: {name}.')
    if name in results_dictionary:
      raise ValueError(f'Duplicate assignment to hyperparameter \'{name}\'')
    is_list = hparams_map[name].is_list
    parse = _make_value_parser(hparams_map[name].scalar_type, name)

    if m_dict['val'] is not None:
      # A single value.
      value = parse(m_dict['val'])
      results_dictionary[name] = [value] if is_list else value

    elif m_dict['strval'] is not None:
      # A quoted string, so trim the quotes.
      value = parse(m_dict['strval'][1:-1])
      results_dictionary[name] = [value] if is_list else value

    elif m_dict['vals'] is not None:
      # A bracketed list; `vals` excludes the surrounding brackets.
      if not is_list:
        raise ValueError(f'Expected single value for hyperparameter {name}, '
                         f'but found {m_dict["vals"]}')
      list_str = m_dict['vals']
      # '[]' is an empty list (indexing an empty `list_str` used to crash).
      elements = list(csv.reader([list_str]))[0] if list_str else []
      results_dictionary[name] = [parse(e.strip()) for e in elements]

    else:  # Not assigned a list or value
      raise ValueError(f'Found empty value for hyperparameter {name}.')

  return results_dictionary
543

544

545
def _build_hparams_map(hparams_class):
  """Collects serialization metadata for every field of an hparams instance.

  Despite its name, `hparams_class` is used as an instance: its fields are
  looked up via `hparams_class.__class__`. The `__attrs_post_init__` installed
  by `s` calls this function with `self`.

  Fields created with hparam.nest() are flattened. Their nested maps, read
  from the nested default instance's `__hparams_map__`, are merged in with
  the optional nest prefix prepended to each abbreviation and the nested field
  name prepended to each path.

  Args:
    hparams_class: An instance of a class decorated with @hparam.s.

  Returns:
    A dict mapping each field's (possibly prefixed) abbreviation to its
    `_FieldInfo`.

  Raises:
    TypeError:
      * if `hparams_class` is not an attrs class/instance.
      * if a nested value is an attrs instance without an `__hparams_map__`
        (i.e. its class was not decorated with @hparam.s).
    AssertionError:
      * if a field carries no hparam abbreviation metadata, i.e. it was not
        created using hparam.field() or hparam.nest().
    KeyError:
      * if two fields in `hparams_class` or any of its nested classes use the
        same abbreviation.
  """
  if not attr.has(hparams_class):
    raise TypeError(
        'Inputs to _build_hparams_map should be classes decorated with '
        '@hparam.s')

  # pytype: disable=invalid-annotation
  factory_type = attr.Factory  # type: Type[attr.Factory]  # pytype: disable=annotation-type-mismatch
  # pytype: enable=invalid-annotation

  collected = {}

  def ensure_unused(abbrev):
    if abbrev in collected:
      raise KeyError(f'Abbrev {abbrev} is duplicated.')

  for attribute in attr.fields(hparams_class.__class__):
    field_path = [attribute.name]
    default = attribute.default
    if isinstance(default, factory_type):
      # Materialize factory defaults (list copies, nested hparams instances).
      default = default.factory()

    if not attr.has(default):
      # Leaf node.
      if attribute.name == _SERIALIZED_ARG:
        continue
      metadata = attribute.metadata
      if _ABBREV_KEY not in metadata:
        raise AssertionError(
            f'Could not find hparam metadata for field {attribute.name}. Did '
            'you create a field without using hparam.field()?')
      abbrev = metadata[_ABBREV_KEY]
      ensure_unused(abbrev)
      collected[abbrev] = _FieldInfo(
          path=field_path,
          scalar_type=metadata[_SCALAR_TYPE_KEY],
          is_list=metadata[_IS_LIST_KEY],
          default_value=attribute.converter(default))
      continue

    # Nested hparams instance: re-home each of its entries under this field.
    if '__hparams_map__' not in default.__dict__:
      raise TypeError('Nested hparams classes must also be decorated with '
                      '@hparam.s.')
    prefix = attribute.metadata.get(_PREFIX_KEY) or ''
    for sub_abbrev, sub_info in default.__hparams_map__.items():
      abbrev = prefix + sub_abbrev
      ensure_unused(abbrev)
      rehomed = copy.copy(sub_info)
      rehomed.path = field_path + sub_info.path
      collected[abbrev] = rehomed

  return collected
617

618

619
def s(wrapped, *attrs_args, **attrs_kwargs):
  """A class decorator for creating an hparams class.

  The resulting class is based on `attr` under the covers, but this wrapper
  provides additional features, such as serialization and deserialization in a
  format that is compatible with (now defunct) tensorflow.HParams, runtime type
  checking, implicit casting where safe to do so (int->float, scalar->list).
  Unlike tensorflow.HParams, this supports hierarchical nesting of parameters
  for better organization, aliasing parameters to short abbreviations for
  compact serialization while maintaining code readability, and support for Enum
  values.

  The decorated class additionally:
    * accepts an optional positional string in its constructor, which is
      applied with `parse()` once all fields are initialized;
    * passes values assigned to fields after construction through the field's
      converter, if it has one;
    * gains `serialize()` and `parse()` methods.

  Example usage:
    @hparam.s
    class MyNestedHParams:
      learning_rate: float = hparam.field(abbrev='lr', default=0.1)
      layer_sizes: List[int] = hparam.field(abbrev='ls', default=[256, 64, 32])

    @hparam.s
    class MyHParams:
      nested_params: MyNestedHParams = hparam.nest(MyNestedHParams)
      non_nested_param: int = hparam.field(abbrev='nn', default=0)

  Args:
    wrapped: The class being decorated. It should only contain fields created
      using `hparam.field()` and `hparam.nest()`.
    *attrs_args: Arguments passed on to `attr.s`.
    **attrs_kwargs: Keyword arguments passed on to `attr.s`.

  Returns:
    The class with the modifications needed to support the additional hparams
    features.
  """

  def attrs_post_init(self):
    # Once __hparams_map__ exists, setattr_impl starts converting assignments,
    # and parse() relies on the map as well.
    self.__hparams_map__ = _build_hparams_map(self)
    serialized = getattr(self, _SERIALIZED_ARG, '')
    if serialized:
      self.parse(serialized)
      # The string has been applied; don't keep it on the instance.
      setattr(self, _SERIALIZED_ARG, '')

  def setattr_impl(self, name, value):
    ready = '__hparams_map__' in self.__dict__
    # Don't mess with setattrs that are called by the attrs framework or during
    # __init__.
    if ready:
      # NOTE: getattr has no default here, so assigning a name that is not an
      # attrs field raises AttributeError after construction.
      attribute = getattr(attr.fields(self.__class__), name)
      if attribute and attribute.converter:
        value = attribute.converter(value)
    super(wrapped, self).__setattr__(name, value)  # pytype: disable=wrong-arg-types

  def serialize(self, readable=False, omit_defaults=False):
    """Serializes this instance to a string.

    Args:
      readable: If True, return a JSON dump of all fields, keyed by their full
        (nested) attribute names, instead of the compact `abbrev=value` format.
      omit_defaults: In the compact format, skip fields whose current value
        equals their default. Ignored when `readable` is True.

    Returns:
      The serialized string.
    """
    if readable:
      d = attr.asdict(self, filter=lambda a, _: a.name != _SERIALIZED_ARG)
      return json.dumps(d, default=str)
    clauses = []
    for key, field_info in self.__hparams_map__.items():
      value = self
      for childname in field_info.path:
        value = getattr(value, childname)
      if not omit_defaults or value != field_info.default_value:
        clauses.append(f'{key}={_serialize_value(value, field_info)}')
    return ','.join(clauses)

  def parse(self, serialized):
    """Assigns field values from a compact `abbrev=value,...` string.

    Raises:
      ValueError: If `serialized` cannot be parsed.
      RuntimeError: If a parsed value cannot be assigned to its field; the
        original exception is chained as the cause.
    """
    parsed_fields = _parse_serialized(serialized, self.__hparams_map__)
    for abbrev, value in parsed_fields.items():
      field_info = self.__hparams_map__[abbrev]
      *parent_path, leaf_name = field_info.path
      parent = self
      for childname in parent_path:
        parent = getattr(parent, childname)
      try:
        setattr(parent, leaf_name, value)
      except Exception as err:  # Re-raised with the field's full path.
        error_field = '.'.join(field_info.path)
        raise RuntimeError(f'Error trying to assign value {value} to field '
                           f'{error_field}.') from err

  wrapped.__hparams_class__ = True
  setattr(
      wrapped, _SERIALIZED_ARG,
      attr.ib(
          default='',
          type=str,
          kw_only=False,
          validator=attr.validators.instance_of(six.string_types),
          repr=False))
  wrapped.__attrs_post_init__ = attrs_post_init
  wrapped.__setattr__ = setattr_impl
  wrapped.serialize = serialize
  wrapped.parse = parse
  wrapped = attr.s(wrapped, *attrs_args, **attrs_kwargs)  # pytype: disable=wrong-arg-types  # attr-stubs
  return wrapped
717

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.