# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Random Network Distillation module."""

import numpy as np
import tensorflow as tf


class RunningStats:
  """Computes streaming statistics.

  Attributes:
    mean: running mean
    var: running variance
    count: number of samples that have been seen
  """

  def __init__(self, shape=()):
    self.mean = np.zeros(shape, dtype=np.float64)
    self.var = np.ones(shape, dtype=np.float64)
    self.count = 0

  def update(self, data):
    """Update the stats based on a batch of data."""
    batch_mean = np.mean(data, axis=0)
    batch_var = np.var(data, axis=0)
    batch_size = len(data)
    self.update_with_moments(batch_mean, batch_var, batch_size)

  def update_with_moments(self, batch_mean, batch_var, batch_size):
    """Distributed update of moments."""
    delta = batch_mean - self.mean
    new_count = self.count + batch_size

    if self.count == 0:
      new_mean = batch_mean
      new_var = batch_var
    else:
      new_mean = self.mean + delta * batch_size / new_count
      m_a = self.var * self.count
      m_b = batch_var * batch_size
      m2 = m_a + m_b + np.square(delta) * self.count * batch_size / new_count
      new_var = m2 / new_count

    self.mean = new_mean
    self.var = new_var
    self.count = new_count


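# The moment combination in `RunningStats.update_with_moments` above is the
# standard parallel update of first and second (biased) moments, in the style
# of Chan et al.'s parallel variance algorithm:
#
#   delta = batch_mean - mean
#   mean' = mean + delta * batch_size / (count + batch_size)
#   M2'   = count * var + batch_size * batch_var
#           + delta**2 * count * batch_size / (count + batch_size)
#   var'  = M2' / (count + batch_size)
#
# A quick numerical check against np.mean / np.var over the concatenated
# batches appears in the usage sketch at the end of this file.

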
class StateRND:
  """RND model from state space alone.

  Attributes:
    output_dim: dimension of the output
    predictor: prediction network that maps an observation to a vector
    target: a fixed, randomly initialized network that the predictor tries to
      imitate
    opt: optimizer
    running_stats: object that keeps track of the running statistics of the
      input states
  """

  def __init__(self, input_dim=10, output_dim=5):
    """Initialize StateRND.

    Args:
      input_dim: dimension of the input
      output_dim: dimension of the output
    """
    self.output_dim = output_dim
    self.predictor = tf.keras.Sequential([
        tf.keras.layers.Dense(
            128, input_shape=(input_dim,), activation='leaky_relu'),
        tf.keras.layers.Dense(128, activation='leaky_relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(output_dim),
    ])
    self.target = tf.keras.Sequential([
        tf.keras.layers.Dense(
            128, input_shape=(input_dim,), activation='leaky_relu'),
        tf.keras.layers.Dense(128, activation='leaky_relu'),
        tf.keras.layers.Dense(output_dim),
    ])
    self.opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
    self.predictor.build()
    self.target.build()
    self.running_stats = RunningStats(shape=(input_dim,))

  def update_stats(self, states):
    """Update the running statistics of the states."""
    self.running_stats.update(np.array(states))

  def _whiten(self, states):
    """Whiten with running statistics."""
    centered = (states - self.running_stats.mean) / np.sqrt(
        self.running_stats.var)
    return centered.clip(-5, 5)

  def compute_intrinsic_reward(self, states):
    """Compute the intrinsic reward/novelty of a batch of states."""
    whitened_states = self._whiten(states)
    states = tf.convert_to_tensor(whitened_states)
    intrinsic_reward = self._diff_norm(states)
    return intrinsic_reward

  def train(self, states):
    """Train the predictor network with a batch of states."""
    whitened_states = self._whiten(states)
    states = tf.convert_to_tensor(whitened_states)
    error = self._train(states)
    return {'prediction_loss': error}

  @tf.function
  def _train(self, states):
    """Runs one gradient step of the predictor on whitened states."""
    pred_variables = self.predictor.variables
    with tf.GradientTape() as tape:
      tape.watch(pred_variables)
      error = self._diff_norm(states)
    grads = tape.gradient(error, pred_variables)
    self.opt.apply_gradients(zip(grads, pred_variables))
    return error

  @tf.function
  def _diff_norm(self, states):
    """Mean squared error between predictor and target outputs."""
    diff = self.predictor(states) - self.target(states)
    error = tf.reduce_mean(tf.reduce_sum(diff**2, axis=1))
    return error
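

if __name__ == '__main__':
  # Minimal usage sketch, not part of the original module: it exercises
  # RunningStats and StateRND on random data. The batch size, dimensions, and
  # number of steps below are illustrative assumptions, not values prescribed
  # by this codebase.
  rng = np.random.default_rng(0)

  # Sanity-check RunningStats against np.mean / np.var over all data seen.
  stats = RunningStats(shape=(3,))
  batches = [rng.normal(size=(16, 3)) for _ in range(4)]
  for batch in batches:
    stats.update(batch)
  all_data = np.concatenate(batches, axis=0)
  np.testing.assert_allclose(stats.mean, np.mean(all_data, axis=0))
  np.testing.assert_allclose(stats.var, np.var(all_data, axis=0))

  # Run a few RND training steps; the prediction loss on a repeated batch
  # should tend to decrease as the predictor learns to imitate the fixed
  # target.
  rnd = StateRND(input_dim=10, output_dim=5)
  states = rng.normal(size=(32, 10)).astype(np.float32)
  rnd.update_stats(states)
  for step in range(5):
    metrics = rnd.train(states)
    print('step', step, 'prediction_loss', float(metrics['prediction_loss']))
  print('intrinsic reward on the same batch:',
        float(rnd.compute_intrinsic_reward(states)))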
