google-research
113 строк · 3.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities for dealing with time-series forecasting models."""
17
18import contextlib19import dataclasses20import random21from typing import Any, Callable, Iterator, List, Sequence, Tuple, Union22
23import numpy as np24import tensorflow as tf25
26WeightList = List[np.ndarray]27
28
def set_seed(random_seed):
  """Seeds the Python, NumPy and TensorFlow global random generators.

  Background on reproducible TensorFlow runs:
  https://github.com/NVIDIA/framework-determinism

  Args:
    random_seed: The seed handed to each of the three generators.
  """
  seeders = (random.seed, np.random.seed, tf.random.set_seed)
  for seeder in seeders:
    seeder(random_seed)
35
@contextlib.contextmanager
def temporary_weights(*keras_objects):
  """Resets the objects' weights to their initial values after the context.

  The weights of every object are captured on entrance to the context. Calling
  the yielded function resets the weights to those captured values. Whether or
  not that function is called, the weights are also reset when the context
  exits, including when the body of the `with` block raises an exception.

  Args:
    *keras_objects: The Keras objects to use temporary weights for. Their
      weights are captured on entrance to the context block and restored to
      those values after the block.

  Yields:
    A function which resets the weights to the captured initial values.

  Raises:
    ValueError: If any object lacks `get_weights` or `set_weights`. This is
      raised by `freeze_weights` before the context body runs.
  """
  starting_weights = freeze_weights(*keras_objects)

  def reset_fn():
    reset_object_weights(starting_weights)

  try:
    yield reset_fn
  finally:
    # Restore even if the body raised, so a failed run does not leave the
    # objects holding modified weights.
    reset_fn()
63
@dataclasses.dataclass(frozen=True)
class KerasObjectWithWeights:
  """Pairs a Keras object with a snapshot of its weights.

  The dataclass is frozen so the pairing cannot be re-assigned after
  construction.

  Attributes:
    keras_object: Any object exposing `get_weights()` and `set_weights()`.
    initial_weights: The weights captured for `keras_object`.
  """
  keras_object: Any
  initial_weights: List[np.ndarray]

  @classmethod
  def from_object(cls, input_obj):
    """Builds a pairing from `input_obj` and its current `get_weights()`."""
    snapshot = input_obj.get_weights()
    return cls(input_obj, snapshot)

  def reset_weights(self):
    """Writes the captured weights back onto the paired object.

    Nothing happens when the snapshot is empty. For example, a layer that had
    not created its weights yet when it was captured is left untouched.
    """
    if not self.initial_weights:
      return
    self.keras_object.set_weights(self.initial_weights)
82
def freeze_weights(
    *keras_objects,
    required_methods=('get_weights', 'set_weights'),
):
  """Captures each object's current weights so they can be restored later.

  Every object is validated before any weights are read.

  Args:
    *keras_objects: The Keras objects whose weights will be captured.
    required_methods: Names of the methods each object must have. The default
      is the `get_weights` / `set_weights` pair used for capturing and
      restoring weights.

  Returns:
    A tuple of `KerasObjectWithWeights`, one per input object and in the same
    order, pairing each object with its weights at call time.

  Raises:
    ValueError: If an object is missing any of `required_methods`. The message
      names the offending object and the methods it lacks.
  """
  for obj in keras_objects:
    missing = [name for name in required_methods if not hasattr(obj, name)]
    if missing:
      raise ValueError(
          f'All of the objects must have the methods {list(required_methods)}:'
          f' {obj} is missing {missing}')

  return tuple(KerasObjectWithWeights.from_object(obj) for obj in keras_objects)
103
def reset_object_weights(layers_and_weights):
  """Restores every paired object to the weights stored alongside it.

  Args:
    layers_and_weights: An iterable of objects with a `reset_weights()` method.
      This is normally the tuple of `KerasObjectWithWeights` returned by
      `freeze_weights`. Each element's `reset_weights()` is called in order.
  """
  for paired in layers_and_weights:
    paired.reset_weights()