google-research

Форк
0
/
variable_replace.py 
198 строк · 6.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Custom getter that allows one to use custom values instead of tf.Variables.

This converts functions with implicit weights -- f(x) -- into functions with
explicit weights -- f(x, params). See `VariableReplaceGetter` for more info.
"""
21

22
import collections
23
import contextlib
24
import tensorflow.compat.v1 as tf
25

26
# Two modes for the custom getter to be in (see
# `VariableReplaceGetter._context_state`).
_UseVariables = collections.namedtuple("UseVariables", [])
_UseValues = collections.namedtuple("UseValues", ["values"])


class VariableReplaceGetter(object):
  """Getter that swaps out internal tf.Variable with tf.Tensor values.

  By default tensorflow hides away access to variables. A function that would
  normally be a function of both data and variables: f(variables, data)
  is presented as a function of just data: f(data) with variables hidden
  away in variable scopes. This custom getter can be used to create functions
  that use tensorflow's neural network construction libraries while exposing
  the underlying variables in a way that users can swap other values in.

  This can be used for anything from evolutionary strategies to unrolled
  optimization.

  ```
  context = VariableReplaceGetter()
  mod = snt.Linear(123, custom_getter=context)

  with context.use_variables():
    y1 = mod(x1) # use variables

  values = context.get_variable_dict()
  # modify the current set of weights
  new_values = {k: v+1 for k,v in values.items()}

  with context.use_value_dict(new_values):
    y2 = mod(x1)
  ```
  """

  def __init__(self, verbose=False):
    """Initializer.

    Args:
      verbose: bool If true, log when inside these contexts.
    """
    self._verbose = verbose

    # Maps variable name -> zero-arg callable that builds the variable's
    # initial value. Filled the first time each trainable variable is seen.
    self._variable_initializer_dict = collections.OrderedDict()

    # Maps variable name -> the object returned by the default getter for
    # each trainable variable requested.
    self._variable_dict = collections.OrderedDict()

    # The current state of the custom getter. Either None (outside any
    # context) or an instance of:
    #  * _UseVariables: the custom getter returns the tf.Variable created
    #    upon first call of tf.get_variable.
    #  * _UseValues: the custom getter returns the user-supplied values
    #    instead of the variables.
    self._context_state = None

  def __call__(self, getter, name, *args, **kwargs):
    """The custom getter.

    Do not call directly, instead pass to a variable_scope.

    The default `getter` is always invoked, so the real variable exists and
    (for trainable variables) is recorded. For trainable variables, the
    returned object is then either that variable or the replacement value,
    depending on the active context.

    Args:
      getter: callable the default getter
      name: str name of variable to get
      *args: args forwarded to `getter`
      **kwargs: kwargs forwarded to `getter`. Must contain "trainable"; for
        trainable variables "initializer" is also read, plus "shape" and
        "dtype" when the initializer is not None.

    Returns:
      tf.Tensor or tf.Variable

    Raises:
      ValueError: if a trainable variable is requested outside of a
        `use_variables` / `use_value_dict` context; if, inside
        `use_value_dict`, a non-trainable variable has an entry in the value
        dict or a trainable variable has none.
    """
    # Only do variable replacement on trainable variables.
    # If not trainable, skip this custom getter.
    # NOTE(review): a `trainable` of None would also take this branch; this
    # assumes the caller always resolves it to a bool -- confirm.
    if not kwargs["trainable"]:
      if self._verbose:
        tf.logging.info("Skipping non-trainable %s", name)

      # If in the _UseValues context ensure that the name of the non-trainable
      # variable has not been given a value, since it would silently be
      # ignored otherwise.
      if isinstance(self._context_state, _UseValues):
        if name in self._context_state.values:
          raise ValueError(
              "The name [%s] was found in the value_dict but it is a"
              " non-trainable variable. Either remove it from the"
              " value_dict or make it trainable!" % name)

      return getter(name, *args, **kwargs)

    if self._context_state is None:
      raise ValueError("Only call in a `use_variables`, `use_value_dict`,"
                       " context!")

    # Always create/fetch the real variable and remember it, so that
    # get_variable_dict() exposes it regardless of the current mode.
    if self._verbose:
      tf.logging.info("Getting %s with normal getter", name)
    orig_var = getter(name, *args, **kwargs)
    self._variable_dict[name] = orig_var

    # Remember how to build this variable's initial value (first time only).
    if name not in self._variable_initializer_dict:
      if kwargs["initializer"] is not None:
        shape = tf.TensorShape(kwargs["shape"]).as_list()
        # pylint: disable=g-long-lambda
        init_fn = lambda: kwargs["initializer"](
            shape=shape, dtype=kwargs["dtype"])
      else:
        # If there is no initializer set, just use the initial value.
        init_fn = lambda: orig_var.initial_value

      self._variable_initializer_dict[name] = init_fn

    # This custom getter is in 1 of two modes -- _UseVariables or _UseValues.
    # The mode is determined by _context_state.
    if isinstance(self._context_state, _UseVariables):
      return orig_var

    if isinstance(self._context_state, _UseValues):
      if self._verbose:
        tf.logging.info("Getting %s from values", name)
      values = self._context_state.values
      if name not in values:
        message = ("Name: %s not specified in the values. \nValid names:\n %s" %
                   (name, "\n".join("    %s" % k for k in values)))
        raise ValueError(message)
      if self._verbose:
        tf.logging.info("Tensor returned %s", values[name])
      return values[name]

    raise ValueError("Bad type of self._context_state. Got [%s]" %
                     type(self._context_state))

  @contextlib.contextmanager
  def use_variables(self):
    """Context for using the tf.Variables which are stored in the tf Graph.

    Yields:
      None

    Raises:
      NotImplementedError: if another context of this getter is active.
    """
    if self._context_state is not None:
      raise NotImplementedError("Nested contexts not allowed at this point.")
    self._context_state = _UseVariables()
    try:
      yield
    finally:
      # Reset even if the body raised; otherwise the getter would stay in
      # this mode and every later context would be rejected as nested.
      self._context_state = None

  @contextlib.contextmanager
  def use_value_dict(self, values):
    """Context for using the values, instead of the local variables.

    Args:
      values: dict maps name to value to use when a tf.get_variable is called

    Yields:
      None

    Raises:
      NotImplementedError: if another context of this getter is active.
    """
    if self._context_state is not None:
      raise NotImplementedError("Nested contexts not allowed at this point.")
    self._context_state = _UseValues(values=values)
    try:
      yield
    finally:
      # Reset even if the body raised (see use_variables).
      self._context_state = None

  def get_initialized_value_dict(self):
    """Return a dictionary of names to tf.Tensor with the initial values.

    Each stored initializer is called anew, so every call builds fresh values.

    Returns:
      collections.OrderedDict with initialized tf.Tensor values, in the order
      the variables were first requested.
    """
    return collections.OrderedDict(
        (name, init_fn())
        for name, init_fn in self._variable_initializer_dict.items())

  def get_variable_dict(self):
    """Return a dictionary of names to tf.Variable with the variables created.

    Note that this is the getter's internal dict, not a copy.

    Returns:
      collections.OrderedDict with tf.Variable.
    """
    return self._variable_dict
199

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.