google-research

Форк
0
/
multi_task.py 
238 строк · 7.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Generic Multi task architecture."""
17

18
import copy
19
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, Generic
20

21
import gin
22

23

24
# Type of the elements stored in the lists of a `NamedLists`.
T = TypeVar('T')
25

26

27
class NamedLists(dict, Generic[T]):
  """A generic architecture for multi tasks with potentially several levels.

  A `NamedLists` is a `dict` whose values are lists. On top of the regular
  `dict` interface:
    - keys can be read as attributes, i.e. `x.key` is `x['key']`;
    - iterating over the container yields the elements of all the lists, one
      list after the other in key order, instead of yielding the keys. Use
      `keys()`, `values()` or `items()` to walk the dict structure itself.
  """

  def __init__(self, layers: Mapping[str, Sequence[T]]):
    """Initializes the container.

    Args:
      layers: maps each key to the elements of that level. Each collection is
        copied into a fresh `list`.
    """
    super().__init__({name: list(values) for name, values in layers.items()})

  def __getattr__(self, attr: str) -> List[T]:
    """Exposes the keys as attributes, i.e. `x.key` is `x['key']`.

    Python only calls this hook when the regular attribute lookup fails.

    Args:
      attr: the name of the attribute being looked up.

    Returns:
      The value stored under the key `attr`.

    Raises:
      AttributeError: if `attr` is not a key of the container.
    """
    try:
      return self[attr]
    except KeyError:
      # Attribute lookups must fail with AttributeError, not KeyError, so that
      # `hasattr`, `getattr(x, name, default)` and the protocols that probe for
      # optional attributes behave as expected.
      raise AttributeError(
          f'{type(self).__name__!r} object has no attribute {attr!r}'
      ) from None

  @property
  def levels(self) -> List[List[T]]:
    """The lists stored in the container, in key order."""
    return list(self.values())

  @property
  def size(self) -> int:
    """The total number of elements, summed over all the lists."""
    return sum(len(level) for level in self.values())

  def constant_copy(self, value: Any) -> 'NamedLists':
    """Returns a copy of the structure with only the same value everywhere."""
    return NamedLists(
        layers={name: [value for _ in level] for name, level in self.items()})

  def copy(self) -> 'NamedLists[T]':
    """Returns a deep copy of the container, as a plain `NamedLists`."""
    return NamedLists(copy.deepcopy(super().copy()))

  def pack(self, values: Sequence[Any], default_value: Any = None) -> 'NamedLists':
    """Packs the values in a NamedLists with the same structure as self.

    Args:
      values: the elements to store, in flattened (iteration) order.
      default_value: fills the positions left over when `values` has fewer
        elements than `self.size`.

    Returns:
      A new `NamedLists` with the same keys and list lengths as `self`.

    Raises:
      StopIteration: if `values` has more elements than `self.size`.
    """
    result = self.constant_copy(default_value)
    # Every writable position of `result`, in flattened order.
    slots = (
        (level, i) for level in result.values() for i in range(len(level)))
    for val in values:
      level, i = next(slots)  # Raises StopIteration on surplus values.
      level[i] = val
    return result

  def flatten(self, empty_value: Any = None) -> Dict[str, Any]:
    """Flattens the container into a single-level dict.

    Args:
      empty_value: placeholder stored under `'<key>/'` for each empty list.

    Returns:
      A dict mapping `'<key>/<index>'` to the matching element, plus one
      `'<key>/'` entry per empty list so that `unflatten` can restore it.
    """
    result = {}
    for name, values in self.items():
      for i, value in enumerate(values):
        result[f'{name}/{i}'] = value
      if not values:  # special case for empty list to keep the structure.
        result[name + '/'] = empty_value
    return result

  @staticmethod
  def unflatten(values: Mapping[str, Any]) -> 'NamedLists':
    """Unflatten a dict of values that have been previously flattened.

    Args:
      values: a mapping with names of the form produced by `flatten`. The
        index after the last '/' is not parsed: the elements of a key are
        stored in the order in which `values` yields them.

    Returns:
      The corresponding `NamedLists`.
    """
    result = dict()
    for name, value in values.items():
      idx = name.rfind('/')
      level = result.setdefault(name[:idx], [])
      # A name ending with '/' is the placeholder of an empty list: the key
      # is created but the placeholder value is dropped.
      if idx != len(name) - 1:
        level.append(value)
    return NamedLists(result)

  class _Iterator:
    """Iterator on NamedLists.

    Yields the elements of each list, one list after the other.
    """

    def __init__(self, container):
      self._level_iter = iter(container.values())
      self._level = None  # The list currently being traversed.
      self._idx = -1  # Position of the last yielded element in `_level`.

    def __iter__(self):
      # Required by the iterator protocol, so that the iterator itself can be
      # used in a `for` loop or passed to `list`, `zip`, ...
      return self

    def __next__(self):
      self._idx += 1
      # Skips over exhausted (or empty) lists.
      while self._level is None or self._idx >= len(self._level):
        self._level = next(self._level_iter)  # Might raise StopIteration here.
        self._idx = 0
      return self._level[self._idx]

  def __iter__(self):
    """Iterates over the elements of all the lists rather than over the keys."""
    return NamedLists._Iterator(self)

  @property
  def shape(self) -> Tuple[int, ...]:
    """The length of each list, in key order."""
    return tuple(len(level) for level in self.levels)
107

108

109
@gin.configurable
class Backbone(NamedLists, Generic[T]):
  """A `NamedLists` with the two keys `embeddings` and `alignments`.

  Specific case of `NamedLists` that is used in sequence alignments.
  """

  def __init__(self,
               embeddings: Sequence[T] = (),
               alignments: Sequence[T] = ()):
    """Initializes the backbone.

    Args:
      embeddings: the elements stored under the `embeddings` key.
      alignments: the elements stored under the `alignments` key.
    """
    layers = {'embeddings': embeddings, 'alignments': alignments}
    super().__init__(layers=layers)

  @classmethod
  def constant_from_shape(cls, value, shape):
    """Builds an instance holding `value` at every position.

    Args:
      value: the value to repeat.
      shape: an indexable whose entries 0 and 1 give the number of embeddings
        and the number of alignments, respectively.

    Returns:
      An instance of `cls` with `shape[0]` embeddings and `shape[1]`
      alignments, all equal to `value`.
    """
    n_embeddings, n_alignments = shape[0], shape[1]
    return cls(
        embeddings=[value for _ in range(n_embeddings)],
        alignments=[value for _ in range(n_alignments)])
123

124

125
@gin.configurable
class SwitchNamedLists(NamedLists[int]):
  """Provides methods to merge N compatible `NamedLists`.

  A `SwitchNamedLists` instance is a `NamedLists[int]` with values in [0, N)
  whose structure matches that of the desired merged `NamedLists` and elements
  indicate from which of the N input `NamedLists` the corresponding output value
  should be taken. That is,
    `output.key[l] = inputs[self.key[l]].key[l]`,
  where `inputs` is a sequence of N `NamedLists`.

  The N input `NamedLists` are assumed to be compatible in the sense that they
  have the same keys and the total number of elements they contain equals the
  number of elements in the `SwitchNamedLists` instance. That is,
    `self.size == sum(inputs_i.size for inputs_i in inputs)`
  must hold true.
  """

  @property
  def n(self) -> int:
    """Returns the number of `NamedLists` being "switched over".

    Computed as one plus the largest index stored in `self`, which assumes
    that the elements lie in [0, n). Empty levels are ignored, and an instance
    without any element yields 0.
    """
    indices = (j for level in self.values() for j in level)
    # `default` keeps `max` from raising on empty levels / empty containers.
    return max(indices, default=-1) + 1

  def filter(self, inputs, i: int) -> NamedLists:
    """Removes elements from `NamedLists` not belonging to i-th input.

    Primarily used to remove "dummy" values e.g. from model output.

    Args:
      inputs: a `NamedLists` with structure identical to `self`.
      i: an int between 0 and N-1, both inclusive, where N is the number of
        `NamedLists` to be merged.

    Returns:
      A `NamedLists` defined as
        `output.key = [v for v, j in zip(inputs.key, self.key) if j == i]`.
      That is, for each key, only those elements in the list for which `self`
      takes value `i` at the matching position will be kept.
    """
    flags = self.get_selector(i)
    # `self.keys()` is needed here: iterating over `self` directly yields the
    # stored elements, not the keys.
    return NamedLists({
        k: [v for v, flag in zip(inputs[k], flags[k]) if flag]
        for k in self.keys()
    })

  def merge(self, inputs) -> NamedLists:
    """Merges a sequence of N compatible `NamedLists`.

    Args:
      inputs: a sequence of N `NamedLists` with the same keys as `self`
        satisfying `self.size == sum(inputs_i.size for inputs_i in inputs)`.

    Returns:
      a `NamedLists` instance with the structure of `self` in which each
      position holds the next not yet consumed element, in flattened order, of
      the input designated by `self` at that position.
    """
    # Flattens each input: iterating a `NamedLists` yields its elements.
    inputs = [list(inputs_i) for inputs_i in inputs]
    offsets = len(inputs) * [0]  # Number of elements consumed per input.
    outputs = []
    for i in list(self):  # Needed to appease AutoGraph?
      outputs.append(inputs[i][offsets[i]])  # pytype: disable=unsupported-operands  # trace-all-classes
      offsets[i] += 1  # pytype: disable=unsupported-operands  # trace-all-classes
    return self.pack(outputs)

  def merge_flattened(
      self,
      inputs: Sequence[Mapping[str, Any]],
      empty_value: Optional[Any] = None,
  ) -> Mapping[str, Any]:
    """Merges a sequence of N compatible, flattened `NamedLists`.

    Args:
      inputs: a sequence of N `Mapping[str, T]` corresponding to N `NamedLists`
        that have been flattened. These must have the same keys as `self` and
        satisfy
          `self.size == sum(unflatten(inputs_i).size for inputs_i in inputs)`.
      empty_value: forwarded to `flatten`, i.e. the placeholder value stored
        for each empty list of the merged result.

    Returns:
      the result of `merge` applied to the unflattened `inputs`, flattened to
      `Mapping[str, T]`.
    """
    # `unflatten` is a static method of `NamedLists`; no need to go through a
    # more specific subclass to reach it.
    merged = self.merge([NamedLists.unflatten(m_i) for m_i in inputs])
    return merged.flatten(empty_value=empty_value)

  def get_selector(self, i: int) -> NamedLists:
    """Returns `NamedLists` of bools flagging elements from i-th input.

    Args:
      i: an int between 0 and N - 1, both inclusive, where N is the number of
        `NamedLists` to be merged.

    Returns:
      a `NamedLists[bool]` such that `output.key[l] = self.key[l] == i`.
    """
    return self.pack([j == i for j in self])
223

224

225
@gin.configurable
class SwitchBackbone(SwitchNamedLists):
  """A `SwitchNamedLists` with the two keys `embeddings` and `alignments`.

  Specific case of `SwitchNamedLists` that is used in sequence alignments.
  """

  def __init__(self,
               embeddings: Sequence[int] = (),
               alignments: Sequence[int] = ()):
    """Initializes the switch.

    Args:
      embeddings: the elements stored under the `embeddings` key.
      alignments: the elements stored under the `alignments` key.
    """
    layers = {'embeddings': embeddings, 'alignments': alignments}
    super().__init__(layers=layers)

  @classmethod
  def constant_like(cls, container, value=0):
    """Builds an instance shaped like `container`, filled with `value`.

    Args:
      container: an object whose `embeddings` and `alignments` attributes are
        iterable; only their number of elements matters.
      value: the value stored at every position.

    Returns:
      An instance of `cls` with one `value` per element of
      `container.embeddings` and of `container.alignments`.
    """
    embeddings = [value for _ in container.embeddings]
    alignments = [value for _ in container.alignments]
    return cls(embeddings=embeddings, alignments=alignments)
239

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.