# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Component for model networks."""

import abc
from typing import Mapping, Optional

import tensorflow.compat.v1 as tf  # tf

# Read-only map from name to Tensor. Used in components that take or return
# multiple named Tensors (e.g. network).
TensorMap = Mapping[str, tf.Tensor]


class Network(metaclass=abc.ABCMeta):
  """Interface for the model network component.

  Note that the output Tensors are expected to be embeddings (e.g. "PreLogits"
  in InceptionV3) instead of predictions (e.g. "Logits"). Head components
  accept embeddings to produce predictions; this design allows us to support
  multiple heads and reuse of the same network body for different predictions
  (e.g. regression and classification).

  Implementations should provide a constructor that takes the parameters needed
  for the specific model, and must implement `build()`.

  Example usage:
    net = network.Inception_v3(inception_v3_params)
    prelogits = net.build({Network.IMAGES: images})[Network.PRE_LOGITS]
  """

  # ---------------------------------------------------------------------------
  # Standard keys for input Tensors to build().

  IMAGES = 'Images'

  # ---------------------------------------------------------------------------
  # Standard keys for output Tensors from build().

  PRE_LOGITS = 'PreLogits'
  LOGITS = 'Logits'
  PROBABILITIES_TENSOR = 'ProbabilitiesTensor'
  PROBABILITIES = 'Probabilities'
  ARM_OUTPUT_TENSOR = 'ArmOutputTensor'

  @abc.abstractmethod
  def build(self, inputs: 'TensorMap') -> 'TensorMap':
    """Builds the network.

    Args:
      inputs: a map from input string names (e.g. `IMAGES`) to tensors.

    Returns:
      A map from output string names (e.g. `PRE_LOGITS`) to tensors.
    """

  @staticmethod
  def _get_tensor(tmap: 'TensorMap',
                  name: str,
                  expected_rank: Optional[int] = None) -> 'tf.Tensor':
    """Returns the specified Tensor from a TensorMap, with error-checking.

    Args:
      tmap: a mapping from string names to Tensors.
      name: the name of the Tensor to return.
      expected_rank: expected rank of the Tensor, for error-checking. Note that
        this checks the static shape (via `tensor.get_shape()`), so a Tensor
        whose static rank is unknown fails the check. Defaults to None, in
        which case the rank is neither checked nor inspected.

    Returns:
      The selected Tensor.

    Raises:
      ValueError: tmap does not contain the specified Tensor (or maps it to
        None), or expected_rank is given and the Tensor's static rank is
        unknown or differs from it.
    """
    tensor = tmap.get(name)
    if tensor is None:
      raise ValueError('Tensor {} not found in TensorMap.'.format(name))
    if expected_rank is None:
      # Rank checking is opt-in. Don't touch the shape at all, so tensors of
      # unknown static rank (for which len() of the shape raises) are returned.
      return tensor
    rank = tensor.get_shape().ndims  # None when the static rank is unknown.
    if rank is None:
      raise ValueError('Tensor {} has unknown rank, but expected {}.'.format(
          name, expected_rank))
    if rank != expected_rank:
      raise ValueError('Tensor {} is of rank {}, but expected {}.'.format(
          name, rank, expected_rank))
    return tensor