google-research
198 строк · 6.8 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Custom getter that allows one to use custom values instead of tf.Variables.
17
18This converts functions with implicit weights -- f(x),
19to explicit -- f(x, params). See `VariableReplaceGetter` for more info.
20"""
21
22import collections
23import contextlib
24import tensorflow.compat.v1 as tf
25
# The custom getter is always in one of two modes, each represented by an
# instance of one of these immutable record types:
#   _UseVariables: hand back the real tf.Variable from the default getter.
#   _UseValues: hand back the tensor stored under the variable's name in
#     `values` instead of the variable.
_UseVariables = collections.namedtuple(typename="UseVariables", field_names=())
_UseValues = collections.namedtuple(typename="UseValues", field_names="values")


class VariableReplaceGetter(object):
  """Getter that swaps out internal tf.Variable with tf.Tensor values.

  By default tensorflow hides away access to variables. A function that would
  normally be a function of both data and variables: f(variables, data)
  is presented as a function of just data: f(data) with variables hidden
  away in variable scopes. This custom getter can be used to create functions
  that use tensorflow's neural network construction libraries while exposing
  the underlying variables in a way that users can swap other values in.

  This can be used for anything from evolutionary strategies to unrolled
  optimization.

  ```
  context = VariableReplaceGetter()
  mod = snt.Linear(123, custom_getter=context)

  with context.use_variables():
    y1 = mod(x1)  # use variables

  values = context.get_variable_dict()
  # modify the current set of weights
  new_values = {k: v+1 for k,v in values.items()}

  with context.use_value_dict(new_values):
    y2 = mod(x1)
  ```
  """

  def __init__(self, verbose=False):
    """Initializer.

    Args:
      verbose: bool. If true, log each variable request handled by this
        getter.
    """
    self._verbose = verbose

    # Maps variable name -> zero-argument callable that builds a tensor with
    # that variable's initial value (see `get_initialized_value_dict`).
    self._variable_initializer_dict = collections.OrderedDict()

    # Maps variable name -> the tf.Variable returned by the default getter.
    self._variable_dict = collections.OrderedDict()

    # Current mode of the custom getter. One of:
    #  * None: outside any context; trainable variables may not be requested.
    #  * _UseVariables: return the tf.Variable created by the default getter
    #    (created on the first call of tf.get_variable).
    #  * _UseValues: return the user-supplied tf.Tensor instead.
    self._context_state = None

  def __call__(self, getter, name, *args, **kwargs):
    """The custom getter.

    Do not call directly, instead pass to a variable_scope.

    Args:
      getter: callable, the default getter.
      name: str, name of the variable to get.
      *args: args forwarded to `getter`.
      **kwargs: kwargs forwarded to `getter`. "trainable" must be present;
        for trainable variables "initializer", "shape" and "dtype" are read
        too.

    Returns:
      tf.Tensor or tf.Variable.

    Raises:
      ValueError: if a non-trainable variable's name appears in the active
        value dict, if a trainable variable is requested outside of a
        `use_variables` / `use_value_dict` context, or if the active value
        dict has no entry for a requested trainable variable.
    """
    # Only do variable replacement on trainable variables.
    # If not trainable, skip this custom getter.
    if not kwargs["trainable"]:
      if self._verbose:
        tf.logging.info("Skipping non-trainable %s", name)

      # A value supplied for a non-trainable variable would be silently
      # ignored, so reject it loudly instead.
      if isinstance(self._context_state, _UseValues):
        if name in self._context_state.values:
          raise ValueError(
              ("The name [%s] was found in the value_dict but it is a"
               " non-trainable variable. Either remove it from the"
               " value_dict or make it trainable!") % name)

      return getter(name, *args, **kwargs)

    if self._context_state is None:
      raise ValueError("Only call in a `use_variables`, `use_value_dict`,"
                       " context!")

    # Always create/fetch the real variable so it exists and is recorded,
    # even in _UseValues mode where a replacement value is returned below.
    if self._verbose:
      tf.logging.info("Getting %s with normal getter", name)
    orig_var = getter(name, *args, **kwargs)
    self._variable_dict[name] = orig_var

    # Record, once per name, how to rebuild this variable's initial value.
    if name not in self._variable_initializer_dict:
      initializer = kwargs.get("initializer")
      if callable(initializer):
        shape = tf.TensorShape(kwargs["shape"]).as_list()
        dtype = kwargs["dtype"]
        init_fn = lambda: initializer(shape=shape, dtype=dtype)
      else:
        # No initializer, or a non-callable one such as a tf.Tensor (which
        # tf.get_variable accepts, typically with shape=None): reuse the
        # variable's own initial value rather than calling it.
        init_fn = lambda: orig_var.initial_value
      self._variable_initializer_dict[name] = init_fn

    # This custom getter is in 1 of two modes -- _UseVariables or _UseValues.
    # The mode is determined by _context_state.
    state = self._context_state
    if isinstance(state, _UseVariables):
      return orig_var
    elif isinstance(state, _UseValues):
      if self._verbose:
        tf.logging.info("Getting %s from values", name)
      values = state.values
      if name not in values:
        message = ("Name: %s not specified in the values. \nValid names:\n %s" %
                   (name, "\n".join(" %s" % k for k in values)))
        raise ValueError(message)
      if self._verbose:
        tf.logging.info("Tensor returned %s", values[name])
      return values[name]
    else:
      raise ValueError("Bad type of self._context_state. Got [%s]" %
                       type(self._context_state))

  @contextlib.contextmanager
  def use_variables(self):
    """Context for using the tf.Variables which are stored in the tf Graph.

    Yields:
      None

    Raises:
      NotImplementedError: if another context of this getter is active.
    """
    if self._context_state is not None:
      raise NotImplementedError("Nested contexts not allowed at this point.")
    self._context_state = _UseVariables()
    try:
      yield
    finally:
      # Reset even if the body raises; otherwise the getter stays stuck in
      # this mode and every later context is rejected as "nested".
      self._context_state = None

  @contextlib.contextmanager
  def use_value_dict(self, values):
    """Context for using the values, instead of the local variables.

    Args:
      values: dict mapping a trainable variable's name to the value returned
        whenever tf.get_variable is called for that name.

    Yields:
      None

    Raises:
      NotImplementedError: if another context of this getter is active.
    """
    if self._context_state is not None:
      raise NotImplementedError("Nested contexts not allowed at this point.")
    self._context_state = _UseValues(values=values)
    try:
      yield
    finally:
      # Reset even if the body raises (see `use_variables`).
      self._context_state = None

  def get_initialized_value_dict(self):
    """Return a dictionary of names to tf.Tensor with the initial values.

    Each recorded initializer is invoked anew, so every call builds fresh
    tensors.

    Returns:
      collections.OrderedDict with initialized tf.Tensor values, in the order
      the variables were first requested.
    """
    return collections.OrderedDict(
        (name, init_fn())
        for name, init_fn in self._variable_initializer_dict.items())

  def get_variable_dict(self):
    """Return a dictionary of names to tf.Variable with the variables created.

    Returns:
      The internal collections.OrderedDict of tf.Variable (not a copy), so it
      also reflects variables requested after this call.
    """
    return self._variable_dict
199