google-research
238 строк · 7.7 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Generic Multi task architecture."""
17
18import copy19from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, Generic20
21import gin22
23
# Generic type of the values stored in the lists of `NamedLists` / `Backbone`.
T = TypeVar('T')
26
class NamedLists(dict, Generic[T]):
  """A generic architecture for multi tasks with potentially several levels.

  A `NamedLists` is a `dict` mapping names to lists of values; every value
  passed to the constructor is copied into a fresh `list`. Lists can be read
  by key (`nl['name']`) or as attributes (`nl.name`).

  Unlike a plain `dict`, iterating over a `NamedLists` yields the values stored
  in its lists, in key order and then list order ("flattened order"), not the
  keys. Use `.keys()`, `.values()` or `.items()` to iterate over the levels.
  """

  class _MissingAttributeError(KeyError, AttributeError):
    """Raised by `__getattr__` when no list is stored under the given name.

    Deriving from `AttributeError` keeps `hasattr`, `getattr(obj, name,
    default)` and `copy.deepcopy` (which probes optional dunder attributes)
    working; deriving from `KeyError` keeps compatibility with callers that
    caught the `KeyError` historically raised here.
    """

  def __init__(self, layers: Mapping[str, Sequence[T]]):
    layers = {k: list(v) for k, v in layers.items()}
    super().__init__(layers)

  def __getattr__(self, attr: str) -> List[T]:
    # Only reached when normal attribute lookup fails, so methods and
    # properties take precedence over keys with the same name.
    try:
      return self[attr]
    except KeyError:
      raise NamedLists._MissingAttributeError(attr) from None

  @property
  def levels(self) -> List[List[T]]:
    """The stored lists, in key order."""
    return list(self.values())

  @property
  def size(self) -> int:
    """Total number of elements over all lists."""
    return sum(len(x) for x in self.values())

  def constant_copy(self, value: T) -> 'NamedLists[T]':
    """Returns a copy of the structure with only the same value everywhere."""
    return NamedLists(
        layers={k: [value for _ in v] for k, v in self.items()})

  def copy(self) -> 'NamedLists[T]':
    """Returns a deep copy of the NamedLists (always a plain `NamedLists`)."""
    return NamedLists(copy.deepcopy(super().copy()))

  def pack(self, values: Sequence[T], default_value=None) -> 'NamedLists[T]':
    """Packs the values in a NamedLists with the same structure as self.

    Args:
      values: values to store, consumed in the flattened order of `self`
        (key order, then list order).
      default_value: value left in every slot not covered by `values`.

    Returns:
      A new `NamedLists` with the same keys and list lengths as `self`.

    Raises:
      ValueError: if `values` holds more elements than `self.size`.
    """
    result = self.constant_copy(default_value)
    # (list, index) pairs addressing every slot of `result` in flattened order.
    slots = ((level, idx) for level in result.values()
             for idx in range(len(level)))
    for val in values:
      try:
        level, idx = next(slots)
      except StopIteration:
        raise ValueError(
            f'Got more values than the {result.size} slots of this NamedLists.'
        ) from None
      level[idx] = val
    return result

  def flatten(self, empty_value: Optional[T] = None) -> Dict[str, Any]:
    """Flattens into a dict keyed by `'<name>/<index>'`.

    Args:
      empty_value: value stored under `'<name>/'` for every empty list, so that
        empty levels survive a `flatten` / `unflatten` round trip.

    Returns:
      A new dict whose keys follow the flattened order of `self`.
    """
    result = {}
    for name, values in self.items():
      for i, value in enumerate(values):
        result[f'{name}/{i}'] = value
      if not values:  # special case for empty list to keep the structure.
        result[name + '/'] = empty_value
    return result

  @staticmethod
  def unflatten(values: Mapping[str, T]) -> 'NamedLists[T]':
    """Unflatten a dict of values that have been previously flattened.

    Each entry `'<name>/<suffix>'` is appended to list `<name>` in the
    iteration order of `values` (the numeric suffix itself is not parsed). An
    entry `'<name>/'` only ensures that `<name>` exists; its value is dropped.

    NOTE(review): names without any '/' are not expected; for them `rfind`
    returns -1 and the resulting key loses its last character. Confirm callers
    never pass such names.
    """
    result = {}
    for name, value in values.items():
      idx = name.rfind('/')
      level = result.setdefault(name[:idx], [])
      if idx != len(name) - 1:  # A trailing '/' marks an empty list.
        level.append(value)
    return NamedLists(result)

  class _Iterator:
    """Iterator over the values of a NamedLists, level after level.

    After each call to `__next__`, `_level[_idx]` is the element just returned.
    """

    def __init__(self, container: 'NamedLists[T]'):
      self._level_iter = iter(container.values())
      self._level = None
      self._idx = -1

    def __iter__(self) -> 'NamedLists._Iterator':
      # Iterators must be iterable themselves (e.g. for `list(iter(nl))`).
      return self

    def __next__(self) -> T:
      self._idx += 1
      # Skip exhausted or empty levels with a loop rather than recursion, so a
      # long run of empty levels cannot exhaust the stack.
      while self._level is None or self._idx >= len(self._level):
        self._level = next(self._level_iter)  # Raises StopIteration when done.
        self._idx = 0
      return self._level[self._idx]

  def __iter__(self) -> 'NamedLists._Iterator':
    return NamedLists._Iterator(self)

  @property
  def shape(self) -> Tuple[int, ...]:
    """Length of each list, in key order."""
    return tuple(len(level) for level in self.levels)
108
@gin.configurable
class Backbone(NamedLists, Generic[T]):
  """Two-level `NamedLists` used for sequence alignments.

  Its levels are `embeddings` followed by `alignments`.
  """

  def __init__(self,
               embeddings: Sequence[T] = (),
               alignments: Sequence[T] = ()):
    levels = {'embeddings': embeddings, 'alignments': alignments}
    super().__init__(layers=levels)

  @classmethod
  def constant_from_shape(cls, value: T,
                          shape: Sequence[int]) -> 'Backbone[T]':
    """Returns an instance whose every slot holds `value`.

    Args:
      value: the object stored in each slot (the same object is reused).
      shape: level sizes; `shape[0]` slots are created for `embeddings` and
        `shape[1]` for `alignments`.

    Returns:
      An instance of `cls` built from the two constant lists.
    """
    def repeated(count):
      return [value for _ in range(count)]

    return cls(embeddings=repeated(shape[0]),
               alignments=repeated(shape[1]))
124
@gin.configurable
class SwitchNamedLists(NamedLists[int]):
  """Provides methods to merge N compatible `NamedLists`.

  A `SwitchNamedLists` instance is a `NamedLists[int]` with values in [0, N)
  whose structure matches that of the desired merged `NamedLists` and whose
  elements indicate from which of the N input `NamedLists` the corresponding
  output value should be taken.

  Values are consumed in flattened order (key order, then list order): walking
  over `self` in that order, the j-th element equal to `i` receives the j-th
  element of `inputs[i]`, where `inputs` is a sequence of N `NamedLists` that
  are also read in flattened order.

  The N input `NamedLists` are assumed to be compatible in the sense that they
  have the same keys and the total number of elements they contain equals the
  number of elements in the `SwitchNamedLists` instance. That is,
  `self.size == sum(inputs_i.size for inputs_i in inputs)`
  must hold true.
  """

  @property
  def n(self) -> int:
    """Returns the number of `NamedLists` being "switched over".

    Assumes elements lie in [0, n). Empty levels are skipped, so a switch with
    an empty level still reports the number of inputs used by its other
    levels; a switch with no elements at all reports 0.
    """
    # `default` avoids the ValueError `max` raises on an empty sequence.
    return max(
        (max(level) for level in self.values() if level), default=-1) + 1

  def filter(self, inputs: NamedLists, i: int) -> NamedLists:
    """Removes elements from `NamedLists` not belonging to i-th input.

    Primarily used to remove "dummy" values e.g. from model output.

    Args:
      inputs: a `NamedLists` with structure identical to `self`.
      i: an int between 0 and N-1, both inclusive, where N is the number of
        `NamedLists` to be merged.

    Returns:
      A `NamedLists` defined as
      `output.key = [v for v, j in zip(inputs.key, self.key) if j == i]`.
      That is, for each key, only those elements in the list for which `self`
      takes value `i` at the matching position will be kept.
    """
    # Same result as masking with `self.get_selector(i)`, without building the
    # intermediate NamedLists of flags.
    return NamedLists({
        key: [value for value, j in zip(inputs[key], switch) if j == i]
        for key, switch in self.items()
    })

  def merge(self, inputs: Sequence[NamedLists]) -> NamedLists:
    """Merges a sequence of N compatible `NamedLists`.

    Args:
      inputs: a sequence of N `NamedLists` with the same keys as `self`
        satisfying `self.size == sum(inputs_i.size for inputs_i in inputs)`.

    Returns:
      a `NamedLists` with the same structure as `self` whose values are
      obtained by walking over `self` in flattened order and, for each element
      `i`, taking the next not-yet-used element of `inputs[i]` (also read in
      flattened order).

    Raises:
      IndexError: if some `inputs[i]` holds fewer elements than the number of
        occurrences of `i` in `self`.
    """
    flat_inputs = [list(inputs_i) for inputs_i in inputs]
    offsets = len(flat_inputs) * [0]  # Next unused position in each input.
    outputs = []
    for i in list(self):  # Needed to appease AutoGraph?
      outputs.append(flat_inputs[i][offsets[i]])  # pytype: disable=unsupported-operands  # trace-all-classes
      offsets[i] += 1  # pytype: disable=unsupported-operands  # trace-all-classes
    return self.pack(outputs)

  def merge_flattened(
      self,
      inputs: Sequence[Mapping[str, T]],
      empty_value: Optional[T] = None,
  ) -> Dict[str, T]:
    """Merges a sequence of N compatible, flattened `NamedLists`.

    Args:
      inputs: a sequence of N `Mapping[str, T]` corresponding to N `NamedLists`
        that have been flattened. These must have the same keys as `self` and
        satisfy
        `self.size == sum(unflatten(inputs_i).size for inputs_i in inputs)`.
      empty_value: value stored under `'<name>/'` for every empty list of the
        merged result (see `NamedLists.flatten`).

    Returns:
      the result of `merge` applied to the unflattened `inputs`, flattened to
      `Mapping[str, T]`.
    """
    unflattened = [NamedLists.unflatten(m_i) for m_i in inputs]
    return self.merge(unflattened).flatten(empty_value=empty_value)

  def get_selector(self, i: int) -> NamedLists:
    """Returns `NamedLists` of bools flagging elements from i-th input.

    Args:
      i: an int between 0 and N - 1, both inclusive, where N is the number of
        `NamedLists` to be merged.

    Returns:
      a `NamedLists[bool]` such that `output.key[l] = self.key[l] == i`.
    """
    return self.pack([j == i for j in self])
224
@gin.configurable
class SwitchBackbone(SwitchNamedLists):
  """`SwitchNamedLists` whose levels mirror those of a `Backbone`.

  Its levels are `embeddings` followed by `alignments`.
  """

  def __init__(self,
               embeddings: Sequence[int] = (),
               alignments: Sequence[int] = ()):
    levels = {'embeddings': embeddings, 'alignments': alignments}
    super().__init__(layers=levels)

  @classmethod
  def constant_like(cls, container, value: int = 0) -> 'SwitchBackbone':
    """Returns an instance shaped like `container`, filled with `value`.

    Args:
      container: an object exposing iterable `embeddings` and `alignments`
        attributes, e.g. a `Backbone`.
      value: the input index stored in every slot.

    Returns:
      An instance of `cls` with one `value` per element of each level.
    """
    def filled(level):
      return [value for _ in level]

    return cls(embeddings=filled(container.embeddings),
               alignments=filled(container.alignments))