google-research

Форк
0
/
model_utils.py 
113 строк · 3.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utilities for dealing with time-series forecasting models."""
17

18
import contextlib
19
import dataclasses
20
import random
21
from typing import Any, Callable, Iterator, List, Sequence, Tuple, Union
22

23
import numpy as np
24
import tensorflow as tf
25

26
WeightList = List[np.ndarray]
27

28

29
def set_seed(random_seed):
30
  # https://github.com/NVIDIA/framework-determinism
31
  random.seed(random_seed)
32
  np.random.seed(random_seed)
33
  tf.random.set_seed(random_seed)
34

35

36
@contextlib.contextmanager
37
def temporary_weights(
38
    *keras_objects,):
39
  """Resets objects weights to their initial values after the context.
40

41
  Calls the getters on entrance to the context and calling the yielded function
42
  resets the weights to this value. Regardless of if this function is called the
43
  weights will also be reset after the context.
44

45
  Args:
46
    *keras_objects: The keras objects to use temporary weights for. The weights
47
      will be gotten on entrance to the context block and reset to that value
48
      after the block.
49

50
  Yields:
51
    A function which will reset the weights to the initial value.
52
  """
53
  # Handle a single keras_object input
54
  starting_weights = freeze_weights(*keras_objects)
55

56
  def reset_fn():
57
    reset_object_weights(starting_weights)
58

59
  yield reset_fn
60

61
  reset_fn()
62

63

64
@dataclasses.dataclass(frozen=True)
65
class KerasObjectWithWeights:
66
  """Helper class to keep a Keras keras_object paired with its weights.
67

68
  Frozen to make sure the association doesn't get mixed up.
69
  """
70
  keras_object: Any
71
  initial_weights: List[np.ndarray]
72

73
  def reset_weights(self):
74
    # When layers are first created there may not be any weights.
75
    if self.initial_weights:
76
      self.keras_object.set_weights(self.initial_weights)
77

78
  @classmethod
79
  def from_object(cls, input_obj):
80
    return cls(input_obj, input_obj.get_weights())
81

82

83
def freeze_weights(
84
    *keras_objects,
85
    required_methods = ('get_weights', 'set_weights'),
86
):
87
  """Freeze the weights with the keras_object so they can be reset later.
88

89
  Args:
90
    *keras_objects: The Keras objects whose weights will be frozen.
91
    required_methods: The methods that each object must have.
92

93
  Returns:
94
    A tuple of objects that pairs the keras_object with its initial weights.
95
  """
96
  for obj in keras_objects:
97
    if not all(hasattr(obj, w_attr) for w_attr in required_methods):
98
      raise ValueError(
99
          f'All of the objects must have get and set weights: {obj}')
100

101
  return tuple(KerasObjectWithWeights.from_object(obj) for obj in keras_objects)
102

103

104
def reset_object_weights(
105
    layers_and_weights):
106
  """Resets a Keras keras_object to it's paired weights.
107

108
  Args:
109
    layers_and_weights: A tuple of objects that pairs a Keras keras_object with
110
      its initial weights. This is normally output from freeze_weights.
111
  """
112
  for current_layer in layers_and_weights:
113
    current_layer.reset_weights()
114

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.