google-research

Форк
0
153 строки · 5.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
# pylint: disable=unused-argument
17
"""Build UVFA with image observation for state input."""
18

19
from __future__ import absolute_import
20
from __future__ import division
21

22
import numpy as np
23
import tensorflow as tf
24
import tensorflow.keras.layers as layers
25

26
from hal.agent.common import tensor_concat
27
from hal.agent.UVFA_repr_tf2.uvfa_state import StateUVFA2
28

29

30
class ImageUVFA2(StateUVFA2):
  """UVFA that uses the image observation.

  Attributes:
    layers_variables: weight variables of all layers
  """

  def __init__(self, cfg):
    # Assigned before the parent constructor runs; the ordering is kept
    # exactly as it was in case the parent reads this attribute.
    self.layers_variables = {}
    StateUVFA2.__init__(self, cfg)

  def _process_image(self, input_shape, cfg, name, embedding_length):
    """Build the goal-conditioned convolutional image encoder.

    Every entry of `cfg.conv_layer_config` is read by index as
    `(filters, kernel_size, strides)`. A negative `filters` value adds a plain
    convolution with `-filters` output channels and a ReLU activation. Any
    other value adds a "Film" block: a linear convolution, batch
    normalization without learned offset or scale, then a per-channel
    multiply and add whose factors are Dense projections of the goal
    embedding, followed by a ReLU.

    Args:
      input_shape: shape of the image input, without the batch dimension
      cfg: configuration object; only `conv_layer_config` is read here
      name: unused; the returned model is always named 'vl_embedding'
      embedding_length: length of the embedding of the instruction

    Returns:
      a Keras model that maps the dict {'state_input', 'goal_embedding'} to
      the output of the last configured layer
    """
    inputs = layers.Input(shape=input_shape)
    # `shape` has to be a tuple: `(embedding_length)` is just a parenthesised
    # int, which Keras versions that iterate over `shape` reject.
    goal_embedding = layers.Input(shape=(embedding_length,))
    # Takes a `(tensor, axis)` pair and inserts a size-1 dimension at `axis`.
    expand_dims = tf.keras.layers.Lambda(
        lambda args: tf.expand_dims(args[0], axis=args[1]))
    out = inputs
    # A dedicated loop variable keeps the `cfg` argument from being rebound.
    for layer_cfg in cfg.conv_layer_config:
      filters = layer_cfg[0]
      kernel_size = layer_cfg[1]
      strides = layer_cfg[2]
      if filters < 0:
        # The sign of `filters` is only a flag; the real channel count is
        # its absolute value.
        conv = layers.Conv2D(
            filters=-filters,
            kernel_size=kernel_size,
            strides=strides,
            activation=tf.nn.relu,
            padding='SAME')
        out = conv(out)
      else:
        # Film
        conv = layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            activation=None,
            padding='SAME')
        out = conv(out)
        # No learned offset/scale here: both come from the goal embedding.
        out = layers.BatchNormalization(center=False, scale=False)(out)
        gamma = layers.Dense(filters)(goal_embedding)
        beta = layers.Dense(filters)(goal_embedding)
        # Insert two size-1 axes after the batch axis so the per-channel
        # factors broadcast against the convolution output.
        gamma = expand_dims((expand_dims((gamma, 1)), 1))
        beta = expand_dims((expand_dims((beta, 1)), 1))
        out = layers.Multiply()([out, gamma])
        out = layers.Add()([out, beta])
        out = layers.ReLU()(out)
    all_inputs = {'state_input': inputs, 'goal_embedding': goal_embedding}
    overall_layer = tf.keras.Model(
        name='vl_embedding', inputs=all_inputs, outputs=out)
    return overall_layer

  def build_q_discrete(self, cfg, name, embedding_length):
    """Build the q value network.

    Args:
      cfg: configuration object
      name: name of the model
      embedding_length: length of the embedding of the instruction

    Returns:
      the q value network
    """
    input_shape = (cfg.img_resolution, cfg.img_resolution, 3)
    inputs = tf.keras.layers.Input(shape=input_shape)
    # One-element tuple, not a bare int in parentheses.
    goal_embedding = tf.keras.layers.Input(shape=(embedding_length,))
    all_inputs = {'state_input': inputs, 'goal_embedding': goal_embedding}
    # Passed to DiscreteFinalLayer as the size of each action axis.
    factors = [8, 10, 10]

    process_layer = self._process_image(input_shape, cfg, name,
                                        embedding_length)
    process_layer.build(input_shape)
    # The final layer sizes its weights from the encoder's output shape.
    out_shape = process_layer.output_shape
    final_layer = DiscreteFinalLayer(factors, out_shape)

    processed_input = process_layer(all_inputs)
    # The goal embedding is fed again to the final layer, alongside the
    # encoded image.
    processed_all_inputs = {
        'state_input': processed_input,
        'goal_embedding': goal_embedding
    }
    q_out = final_layer(processed_all_inputs)
    model = tf.keras.Model(name=name, inputs=all_inputs, outputs=q_out)
    return model
109

110

111
class DiscreteFinalLayer(layers.Layer):
  """Keras layer for projection.

  Attributes:
      factors: size of each action axis
      out_shape: shape of the incoming 'state_input' tensor; the leading entry
        is skipped, the middle entries are flattened together and the last
        entry is used as the channel size
  """

  def __init__(self, factors, out_shape):
    super(DiscreteFinalLayer, self).__init__()
    self.factors = factors
    self.out_shape = out_shape
    self._initializer = tf.initializers.glorot_uniform()
    # One row per entry of every action axis, one column per flattened
    # position of the input; the leading 1 is tiled to the batch in `call`.
    projection_shape = (1, sum(factors), np.prod(out_shape[1:-1]))
    self._projection_mat = tf.Variable(
        self._initializer(shape=projection_shape),
        trainable=True,
        name='projection_matrix')
    # Maps the goal embedding to the same channel size as the input.
    self._dense_layer = layers.Dense(out_shape[-1])
    # Stack of 1x1 convolutions ending in a single output channel.
    self._conv_layer_1 = layers.Conv2D(filters=100, kernel_size=1, strides=1)
    self._conv_layer_2 = layers.Conv2D(filters=32, kernel_size=1, strides=1)
    self._conv_layer_3 = layers.Conv2D(filters=1, kernel_size=1, strides=1)

  def call(self, inputs):
    """Compute one output value per combination of action-axis entries.

    Args:
      inputs: dict holding 'state_input' and 'goal_embedding' tensors

    Returns:
      tensor of shape [batch, prod(factors)]
    """
    instruction = inputs['goal_embedding']
    features = inputs['state_input']
    num_positions = np.prod(self.out_shape[1:-1])
    channels = self.out_shape[-1]
    num_combinations = np.prod(self.factors)

    # Repeat the shared projection matrix once per example in the batch.
    batch_size = tf.shape(features)[0]
    tiled_projection = tf.tile(self._projection_mat, [batch_size, 1, 1])
    # [B, positions, C]
    flat_features = tf.reshape(features, (-1, num_positions, channels))
    # [B, sum(factors), C]
    projected = tf.matmul(tiled_projection, flat_features)
    # [B, factors[0], C] [B, factors[1], C] [B, factors[2], C]
    first, second, third = tf.split(projected, self.factors, axis=1)
    # `tensor_concat` comes from hal.agent.common; the reshape below implies
    # it yields prod(factors) * 3 * C values per example.
    combined = tensor_concat(first, second, third)
    # [B, prod(factors), C * 3]
    combined = tf.reshape(combined, [-1, num_combinations, channels * 3])

    # Project the goal embedding to C channels and repeat it for every
    # combination: [B, prod(factors), C]
    goal_features = self._dense_layer(instruction)
    goal_features = tf.expand_dims(goal_features, 1)
    goal_features = tf.tile(goal_features, multiples=[1, num_combinations, 1])

    # [B, prod(factors), C * 4]
    merged = tf.concat([combined, goal_features], axis=-1)
    # Add a size-1 axis so the 1x1 Conv2D layers receive a rank-4 tensor.
    merged = tf.expand_dims(merged, axis=1)
    hidden = tf.nn.relu(self._conv_layer_1(merged))
    hidden = tf.nn.relu(self._conv_layer_2(hidden))
    scores = self._conv_layer_3(hidden)
    # Drop the size-1 axis added above and the single output channel.
    return tf.squeeze(scores, axis=[1, 3])
154

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.