google-research
138 строк · 4.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Random Network Distillation module."""
17
18import numpy as np
19import tensorflow as tf
20
21
class RunningStats:
  """Maintains streaming mean/variance statistics over batches.

  Uses the parallel (Chan et al.) moment-merging update, so batches may be
  folded in one at a time without keeping the raw samples.

  Attributes:
    mean: running mean
    var: running variance
    count: number of samples that have been seen
  """

  def __init__(self, shape=()):
    self.mean = np.zeros(shape, dtype=np.float64)
    self.var = np.ones(shape, dtype=np.float64)
    self.count = 0

  def update(self, data):
    """Fold a batch of raw samples into the running statistics."""
    self.update_with_moments(
        np.mean(data, axis=0), np.var(data, axis=0), len(data))

  def update_with_moments(self, batch_mean, batch_var, batch_size):
    """Merge externally computed batch moments into the running moments."""
    total = self.count + batch_size

    if self.count == 0:
      # First batch: adopt its moments directly.
      merged_mean = batch_mean
      merged_var = batch_var
    else:
      delta = batch_mean - self.mean
      merged_mean = self.mean + delta * batch_size / total
      # Combine second moments about each mean, plus the between-batch term.
      m2 = (self.var * self.count + batch_var * batch_size
            + np.square(delta) * self.count * batch_size / total)
      merged_var = m2 / total

    self.mean = merged_mean
    self.var = merged_var
    self.count = total
62
63
class StateRND:
  """RND model from state space alone.

  Random Network Distillation: a fixed, randomly initialized `target` network
  maps states to embeddings, and a trained `predictor` network tries to match
  it. The prediction error serves as an intrinsic (novelty) reward — states
  seen often are predicted well and score low.

  Attributes:
    output_dim: dimension of the output embedding
    predictor: trained network that maps an observation to a vector
    target: a fixed random network that predictor tries to imitate
    opt: optimizer for the predictor
    running_stats: running statistics of the input states, used to whiten
      observations before feeding either network
  """

  def __init__(self, input_dim=10, output_dim=5):
    """Initialize StateRND.

    Args:
      input_dim: dimension of the input
      output_dim: dimension of the output
    """
    self.output_dim = output_dim
    self.predictor = tf.keras.Sequential([
        tf.keras.layers.Dense(
            128, input_shape=(input_dim,), activation='leaky_relu'),
        tf.keras.layers.Dense(128, activation='leaky_relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(output_dim),
    ])
    # The target stays at its random initialization; only the predictor is
    # ever updated, so the gap between them measures state novelty.
    self.target = tf.keras.Sequential([
        tf.keras.layers.Dense(
            128, input_shape=(input_dim,), activation='leaky_relu'),
        tf.keras.layers.Dense(128, activation='leaky_relu'),
        tf.keras.layers.Dense(output_dim),
    ])
    # `lr` is a deprecated alias removed in newer Keras; use `learning_rate`.
    self.opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
    self.predictor.build()
    self.target.build()
    self.running_stats = RunningStats(shape=(input_dim,))

  def update_stats(self, states):
    """Update the running statistics of the states."""
    self.running_stats.update(np.array(states))

  def _whiten(self, states):
    """Whiten with running statistics and clip to [-5, 5].

    NOTE(review): divides by sqrt(var) with no epsilon; safe before any
    update (var starts at ones) but a constant feature dimension would
    yield a zero divisor — confirm callers always see varying states.
    """
    centered = (states-self.running_stats.mean)/np.sqrt(self.running_stats.var)
    return centered.clip(-5, 5)

  def compute_intrinsic_reward(self, states):
    """Compute the intrinsic reward/novelty of a batch of states."""
    whitened_states = self._whiten(states)
    states = tf.convert_to_tensor(whitened_states)
    intrinsic_reward = self._diff_norm(states)
    return intrinsic_reward

  def train(self, states):
    """Train the predictor network with a batch of states.

    Args:
      states: batch of states; whitened internally before the update.

    Returns:
      Dict with key 'prediction_loss' holding the scalar prediction error.
    """
    whitened_states = self._whiten(states)
    states = tf.convert_to_tensor(whitened_states)
    error = self._train(states)
    return {'prediction_loss': error}

  @tf.function
  def _train(self, states):
    """One gradient step minimizing predictor-vs-target error."""
    pred_variables = self.predictor.variables
    with tf.GradientTape() as tape:
      tape.watch(pred_variables)
      error = self._diff_norm(states)
    grads = tape.gradient(error, pred_variables)
    self.opt.apply_gradients(zip(grads, pred_variables))
    return error

  @tf.function
  def _diff_norm(self, states):
    """Mean (over batch) squared L2 norm of predictor-target difference."""
    diff = self.predictor(states) - self.target(states)
    error = tf.reduce_mean(tf.reduce_sum(diff**2, axis=1))
    return error
139