google-research
153 lines · 5.4 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# pylint: disable=unused-argument
17"""Build UVFA with image observation for state input."""
18
19from __future__ import absolute_import
20from __future__ import division
21
22import numpy as np
23import tensorflow as tf
24import tensorflow.keras.layers as layers
25
26from hal.agent.common import tensor_concat
27from hal.agent.UVFA_repr_tf2.uvfa_state import StateUVFA2
28
29
class ImageUVFA2(StateUVFA2):
  """UVFA that uses the image observation.

  Attributes:
    layers_variables: weight variables of all layers
  """

  def __init__(self, cfg):
    self.layers_variables = {}
    StateUVFA2.__init__(self, cfg)

  def _process_image(self, input_shape, cfg, name, embedding_length):
    """Build the convolutional trunk that embeds the image observation.

    Layers with a negative filter count in `cfg.conv_layer_config` are plain
    ReLU convolutions; non-negative entries are FiLM blocks where the goal
    embedding predicts a per-channel scale (gamma) and shift (beta) applied
    after an affine-free batch norm.

    Args:
      input_shape: shape of the image input, e.g. (H, W, 3)
      cfg: configuration object with a `conv_layer_config` list of
        (filters, kernel_size, strides) tuples
      name: name of the model (unused here; kept for interface parity)
      embedding_length: length of the embedding of the instruction

    Returns:
      a keras Model mapping {'state_input', 'goal_embedding'} to the
      conditioned feature map
    """
    inputs = layers.Input(shape=input_shape)
    # BUG FIX: `(embedding_length)` is just an int in parentheses; Keras
    # expects an iterable shape, so use a one-element tuple.
    goal_embedding = layers.Input(shape=(embedding_length,))
    expand_dims = tf.keras.layers.Lambda(
        lambda inputs: tf.expand_dims(inputs[0], axis=inputs[1]))
    out = inputs
    # BUG FIX: the original loop variable shadowed `cfg` itself
    # (`for cfg in cfg.conv_layer_config`); use a distinct name.
    for layer_cfg in cfg.conv_layer_config:
      if layer_cfg[0] < 0:
        # Negative filter count marks an unconditioned ReLU conv layer.
        conv = layers.Conv2D(
            filters=-layer_cfg[0],
            kernel_size=layer_cfg[1],
            strides=layer_cfg[2],
            activation=tf.nn.relu,
            padding='SAME')
        out = conv(out)
      else:
        # FiLM: conv -> batch norm without its own affine terms -> scale and
        # shift predicted from the goal embedding -> ReLU.
        conv = layers.Conv2D(
            filters=layer_cfg[0],
            kernel_size=layer_cfg[1],
            strides=layer_cfg[2],
            activation=None,
            padding='SAME')
        out = conv(out)
        out = layers.BatchNormalization(center=False, scale=False)(out)
        gamma = layers.Dense(layer_cfg[0])(goal_embedding)
        beta = layers.Dense(layer_cfg[0])(goal_embedding)
        # Broadcast (B, C) -> (B, 1, 1, C) so they apply per channel.
        gamma = expand_dims((expand_dims((gamma, 1)), 1))
        beta = expand_dims((expand_dims((beta, 1)), 1))
        out = layers.Multiply()([out, gamma])
        out = layers.Add()([out, beta])
        out = layers.ReLU()(out)
    all_inputs = {'state_input': inputs, 'goal_embedding': goal_embedding}
    overall_layer = tf.keras.Model(
        name='vl_embedding', inputs=all_inputs, outputs=out)
    return overall_layer

  def build_q_discrete(self, cfg, name, embedding_length):
    """Build the q value network.

    Args:
      cfg: configuration object
      name: name of the model
      embedding_length: length of the embedding of the instruction

    Returns:
      the q value network
    """
    input_shape = (cfg.img_resolution, cfg.img_resolution, 3)
    inputs = tf.keras.layers.Input(shape=input_shape)
    # BUG FIX: shape must be a tuple, not a parenthesized int.
    goal_embedding = tf.keras.layers.Input(shape=(embedding_length,))
    all_inputs = {'state_input': inputs, 'goal_embedding': goal_embedding}
    # Sizes of the three discrete action axes (8 * 10 * 10 joint actions).
    factors = [8, 10, 10]

    process_layer = self._process_image(input_shape, cfg, name,
                                        embedding_length)
    process_layer.build(input_shape)
    out_shape = process_layer.output_shape
    final_layer = DiscreteFinalLayer(factors, out_shape)

    processed_input = process_layer(all_inputs)
    processed_all_inputs = {
        'state_input': processed_input,
        'goal_embedding': goal_embedding
    }
    q_out = final_layer(processed_all_inputs)
    model = tf.keras.Model(name=name, inputs=all_inputs, outputs=q_out)
    return model
109
110
class DiscreteFinalLayer(layers.Layer):
  """Keras layer projecting image features onto the discrete action space.

  Attributes:
    factors: size of each action axis
    out_shape: shape of the action
  """

  def __init__(self, factors, out_shape):
    super(DiscreteFinalLayer, self).__init__()
    self.factors = factors
    self.out_shape = out_shape
    self._initializer = tf.initializers.glorot_uniform()
    # One learned projection row per action-axis bin, over flattened space.
    num_rows = sum(factors)
    spatial_size = np.prod(out_shape[1:-1])
    self._projection_mat = tf.Variable(
        self._initializer(shape=(1, num_rows, spatial_size)),
        trainable=True,
        name='projection_matrix')
    self._dense_layer = layers.Dense(out_shape[-1])
    self._conv_layer_1 = layers.Conv2D(100, 1, 1)
    self._conv_layer_2 = layers.Conv2D(32, 1, 1)
    self._conv_layer_3 = layers.Conv2D(1, 1, 1)

  def call(self, inputs):
    goal = inputs['goal_embedding']
    state = inputs['state_input']
    spatial_size = np.prod(self.out_shape[1:-1])
    channels = self.out_shape[-1]
    num_actions = np.prod(self.factors)

    # Replicate the shared projection matrix across the batch.
    batch = tf.shape(state)[0]
    projection = tf.tile(self._projection_mat, multiples=[batch, 1, 1])
    # Flatten spatial dims, then project onto sum(factors) rows.
    flat_state = tf.reshape(state, (-1, spatial_size, channels))
    projected = tf.matmul(projection, flat_state)
    # [B, factor[0], s3] [B, factor[1], s3] [B, factor[2], s3]
    fac1, fac2, fac3 = tf.split(projected, self.factors, axis=1)
    joint = tensor_concat(fac1, fac2, fac3)  # [B, f1, f2, f3, s3]
    # [B, 800, s3*3]
    joint = tf.reshape(joint, [-1, num_actions, channels * 3])

    # Tile the projected goal so it pairs with every joint action.
    goal_feat = tf.expand_dims(self._dense_layer(goal), 1)
    goal_feat = tf.tile(goal_feat, multiples=[1, num_actions, 1])

    merged = tf.concat([joint, goal_feat], axis=-1)
    merged = tf.expand_dims(merged, axis=1)
    hidden = tf.nn.relu(self._conv_layer_1(merged))
    hidden = tf.nn.relu(self._conv_layer_2(hidden))
    q_values = self._conv_layer_3(hidden)
    return tf.squeeze(q_values, axis=[1, 3])
154