google-research

Форк
0
97 строк · 3.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Component for model networks."""
17

18
import abc
19
from typing import Mapping, Optional
20

21
import tensorflow.compat.v1 as tf  # tf
22

23
# Read-only map from name to Tensor. Used in components that take or return
24
# multiple named Tensors (e.g. network).
25
TensorMap = Mapping[str, tf.Tensor]
26

27

28
class Network(metaclass=abc.ABCMeta):
29
  """Interface for the model network component.
30

31
  Note that the output Tensors are expected to be embeddings (e.g. "PreLogits"
32
  in InceptionV3) instead of predictions (e.g. "Logits"). Head components
33
  accept embeddings to produce predictions; this design allows us to support
34
  multi-head and reusing the same network body for different predictions
35
  (e.g. regression and classification).
36

37
  Implementations should provide a constructor that takes parameters needed for
38
  the specific model.
39

40
  Example usage:
41
  net = network.Inception_v3(inception_v3_params)
42
  prelogits = net.build(images)['out']
43
  """
44

45
  # ---------------------------------------------------------------------------
46
  # Standard keys for input Tensors to build().
47

48
  IMAGES = 'Images'
49

50
  # ---------------------------------------------------------------------------
51
  # Standard keys for output Tensors from build().
52

53
  PRE_LOGITS = 'PreLogits'
54
  LOGITS = 'Logits'
55
  PROBABILITIES_TENSOR = 'ProbabilitiesTensor'
56
  PROBABILITIES = 'Probabilities'
57
  ARM_OUTPUT_TENSOR = 'ArmOutputTensor'
58

59
  @abc.abstractmethod
60
  def build(self, inputs):
61
    """Builds the network.
62

63
    Args:
64
      inputs: a map from input string names to tensors.
65

66
    Returns:
67
      A map from output string names to tensors.
68
    """
69

70
  @staticmethod
71
  def _get_tensor(tmap,
72
                  name,
73
                  expected_rank = None):
74
    """Returns the specified Tensor from a TensorMap, with error-checking.
75

76
    Args:
77
      tmap: a mapping from string names to Tensors.
78
      name: the name of the Tensor to return.
79
      expected_rank: expected rank of the Tensor, for error-checking. Note that
80
        this checks static shape (e.g. via tensor.get_shape()). Defaults to not
81
        checked.
82

83
    Returns:
84
      The selected Tensor.
85

86
    Raises:
87
      ValueError: tmap does not contain the specified Tensor, or the Tensor is
88
        not of expected rank.
89
    """
90
    tensor = tmap.get(name, None)
91
    if tensor is None:
92
      raise ValueError('Tensor {} not found in TensorMap.'.format(name))
93
    rank = len(tensor.get_shape())
94
    if expected_rank is not None and rank != expected_rank:
95
      raise ValueError('Tensor {} is of rank {}, but expected {}.'.format(
96
          name, rank, expected_rank))
97
    return tensor
98

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.