google-research

Форк
0
298 строк · 9.6 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Various utilities that involve TensorFlow and accelerator devices."""
18
import logging
19
from typing import Dict, List, Optional, Tuple, Union  # pylint: disable=g-import-not-at-top
20

21
import colorama
22
import dataclasses
23
import numpy as np
24
import tensorflow as tf
25
import tensorflow.python.distribute.values as values
26
import tensorflow.python.eager.context as context
27
import tensorflow.python.framework.ops as ops
28
import tensorflow.python.tpu.topology as topology
29
import utils
30

31
# Module-level logger, named after this module.
LOGGER = logging.getLogger(name=__name__)
32

33

34
@dataclasses.dataclass
class TpuConfigType:
  """TPU connection state, as returned by `init_tpus`."""
  # Resolver that was used to connect to the TPU cluster.
  resolver: tf.distribute.cluster_resolver.TPUClusterResolver
  # Topology returned by `tf.tpu.experimental.initialize_tpu_system`.
  topology: topology.Topology
38

39

40
@dataclasses.dataclass
class DevicesMapType:
  """Logical devices grouped by device type, as built by `device_mapping`."""
  # Field names are capitalized device-type plurals, hence the pylint pragma.
  # pylint: disable=invalid-name
  # Logical TPU devices.
  TPUs: List[context.LogicalDevice]
  # Logical GPU devices.
  GPUs: List[context.LogicalDevice]
  # Logical CPU devices.
  CPUs: List[context.LogicalDevice]
  # pylint: enable=invalid-name
47

48

49
def init_tpus(tpu_name = None):
  """Connects to the TPU cluster and initializes the TPU system.

  Args:
    tpu_name: Name of the TPU to resolve. When falsy, the
      `TPUClusterResolver` is constructed without arguments.

  Returns:
    A `TpuConfigType` holding the resolver and the topology returned by
    `tf.tpu.experimental.initialize_tpu_system`. Returns None when
    resolving, connecting or initializing raises a `ValueError`; in that
    case a colored "WORKING WITH CPUS" warning is logged.
  """
  try:
    resolver_class = tf.distribute.cluster_resolver.TPUClusterResolver
    resolver = resolver_class(tpu_name) if tpu_name else resolver_class()
    if not resolver:
      return None
    tf.config.experimental_connect_to_cluster(resolver)
    tpu_topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    return TpuConfigType(resolver=resolver, topology=tpu_topology)
  except ValueError:
    # No usable TPU: warn (in color) that the run continues on CPUs.
    style_codes = {
        "red_bg": colorama.Back.RED,
        "white": colorama.Fore.WHITE,
        "bold": colorama.Style.BRIGHT,
        "reset": colorama.Style.RESET_ALL,
    }
    LOGGER.warning(
        "%(red_bg)s%(white)s%(bold)s WORKING WITH CPUS %(reset)s",
        style_codes,
    )
    return None
71

72

73
def devices_to_use():
  """Returns the logical devices of the accelerator we are most likely to use.

  TPUs are preferred over GPUs, and GPUs over CPUs. The first of these types
  that has at least one logical device is selected, so every returned device
  has the same device type.

  Returns:
    Sorted list of the logical devices of the selected type.
  """
  for device_type in ("TPU", "GPU"):
    # Enumerate each type once and reuse the result (the previous version
    # listed the winning type twice).
    devices = tf.config.list_logical_devices(device_type)
    if devices:
      return sorted(devices)
  return sorted(tf.config.list_logical_devices("CPU"))
87

88

89
def current_accelerator_type():
  """Returns the device type of the accelerators we are using.

  `devices_to_use` only returns devices of a single type, so the type of its
  first device is the type of all of them.
  """
  selected_devices = devices_to_use()
  return selected_devices[0].device_type
95

96

97
def device_mapping():
  """Lists the available logical devices, grouped by device type.

  Returns:
    A `DevicesMapType` whose `TPUs`, `GPUs` and `CPUs` fields each hold the
    sorted logical devices of that type (possibly an empty list).
  """
  by_type = {
      device_type: sorted(tf.config.list_logical_devices(device_type))
      for device_type in ("TPU", "GPU", "CPU")
  }
  return DevicesMapType(
      TPUs=by_type["TPU"],
      GPUs=by_type["GPU"],
      CPUs=by_type["CPU"],
  )
104

105

106
def make_dict_distribute_fn(batch):
  """Builds the dict distribution function.

  Args:
    batch: A dict whose values are sliceable along their first axis and all
      have the same length (the batch size).

  Returns:
    A function that takes a context object exposing `num_replicas_in_sync`
    and `replica_id_in_sync_group` and returns a dict with the same keys,
    holding that replica's contiguous slice of every entry.
  """
  def dict_distribute_fn(ctx):
    """Slices this replica's share out of every entry of `batch`.

    Assumes all the entries in the dict have the same batch size (only the
    first entry's length is read). Each replica gets
    `batch_size // num_replicas_in_sync` rows, so when the batch size is not
    a multiple of the replica count, the trailing remainder rows are not
    assigned to any replica.
    """
    if not batch:
      # Nothing to split; avoids the StopIteration that `next(iter(...))`
      # raises on an empty dict.
      return {}
    batch_size = len(next(iter(batch.values())))
    per_replica = batch_size // ctx.num_replicas_in_sync
    start = per_replica * ctx.replica_id_in_sync_group
    end = start + per_replica
    return {key: entry[start:end] for key, entry in batch.items()}
  return dict_distribute_fn
118

119

120
def deal_w_entry(strategy_outputs):
  """Merges the per-replica components of a distributed value.

  Args:
    strategy_outputs: A distributed value (e.g. `values.PerReplica`) whose
      `.values` attribute holds the per-replica results.

  Returns:
    The per-replica results concatenated along axis 0 when `.values` is a
    tuple; otherwise `.values` itself, unchanged.
  """
  output = strategy_outputs.values  # pytype: disable=attribute-error
  # The tuple is `.values`, not the distributed wrapper itself (which is not
  # a tuple). Callers such as `process_strat_output` read `.shape` on the
  # result, so the per-replica tuple must be concatenated into one tensor.
  if isinstance(output, tuple):
    output = tf.concat(output, axis=0)
  return output
125

126

127
def process_strat_output(
    strategy_outputs,
    name,
    strategy,
    current_batch_size,
):
  """Uniformizes the different outputs of `strategy.run` calls.

  The cases are checked in this order, and the first match wins:
    * a single `values.PerReplica`: its replicas are merged with
      `deal_w_entry`, and `utils.check_equal` compares the merged shape with
      `current_batch_size`;
    * a tuple whose first element is a `values.PerReplica`: every element is
      merged with `deal_w_entry` and a tuple is returned;
    * a dict whose first value is a `values.PerReplica`: every value is merged
      with `deal_w_entry` and a dict with the same keys is returned;
    * an `ops.EagerTensor`, or a tuple whose first element is one: returned
      as is.
  Only the first element/value of a tuple/dict is inspected to pick a case.

  Args:
    strategy_outputs: The value returned by `strategy.run`.
    name: Label used only in the error message.
    strategy: The distribution strategy; only its type is used, in the error
      message.
    current_batch_size: Expected size, checked only in the single
      `PerReplica` case.

  Returns:
    The uniformized output, as described above.

  Raises:
    RuntimeError: if `strategy_outputs` matches none of the cases above.
  """
  if isinstance(strategy_outputs, values.PerReplica):
    merged = deal_w_entry(strategy_outputs)
    utils.check_equal(merged.shape, current_batch_size)
    return merged

  if (isinstance(strategy_outputs, tuple) and
      isinstance(strategy_outputs[0], values.PerReplica)):
    return tuple(
        deal_w_entry(replica_value) for replica_value in strategy_outputs
    )

  if (isinstance(strategy_outputs, dict) and
      isinstance(next(iter(strategy_outputs.values())), values.PerReplica)):
    return {
        key: deal_w_entry(replica_value)
        for key, replica_value in strategy_outputs.items()
    }

  already_local = isinstance(strategy_outputs, ops.EagerTensor) or (
      isinstance(strategy_outputs, tuple) and
      isinstance(strategy_outputs[0], ops.EagerTensor))
  if already_local:
    return strategy_outputs

  raise RuntimeError(
      f"{name}: {type(strategy_outputs)}, {type(strategy)}"
  )
163

164

165
def load_reference_db(
    checkpoint_path, variable_name
):
  """Loads the reference database for retrieval from a checkpoint.

  This is mostly for compatibility with the REALM code.

  Args:
    checkpoint_path: The path of the checkpoint to use; it is converted with
      `str()` before being passed to `tf.train.load_checkpoint`.
    variable_name: The variable name of the database inside of the
      checkpoint. If the checkpoint has no tensor under exactly this key, the
      key `variable_name + "/.ATTRIBUTES/VARIABLE_VALUE"` is read instead.

  Returns:
    A numpy array with the reference database.

  Raises:
    Whatever `get_tensor` raises when the fallback key is also missing
    (presumably `tf.errors.NotFoundError`).
  """
  reader = tf.train.load_checkpoint(str(checkpoint_path))
  try:
    return reader.get_tensor(variable_name)
  except tf.errors.NotFoundError:
    # Fall back to the ".ATTRIBUTES/VARIABLE_VALUE" key layout (presumably
    # that of object-based TF2 checkpoints).
    fallback_key = variable_name + "/.ATTRIBUTES/VARIABLE_VALUE"
    return reader.get_tensor(fallback_key)
188

189

190
def mips_exact_search(
    vectors, num_neighbors, db, sorted_results=True
):
  """Does exact maximum-inner-product retrieval over a database.

  Args:
    vectors: The key vectors to retrieve with.
    num_neighbors: The number of neighbors to extract.
    db: The vector database to retrieve from. It is multiplied with
      `transpose_b=True`, so its last dimension must match that of `vectors`.
    sorted_results: Whether the neighbors are returned sorted by decreasing
      inner product. Defaults to True.

  Returns:
    top_k: The top_k indices, the retrieved neighbors.
    inners: The inner products of each neighbor.
  """
  product = tf.linalg.matmul(vectors, db, transpose_b=True)
  # `tf.math.top_k` requires `sorted` to be a Python bool; passing anything
  # else (such as the builtin `sorted` function) raises a TypeError.
  inners, top_k = tf.math.top_k(
      product, k=num_neighbors, sorted=bool(sorted_results))
  return top_k, inners
207

208

209
def sample_without_replacement(logits, k):
  """Samples k values without replacement, from a set of logits.

  Uses the Gumbel-top-k trick: each logit is perturbed with independent
  Gumbel noise, and the indices of the k largest perturbed values are kept.

  Courtesy of https://github.com/tensorflow/tensorflow/issues/9260#issuecomment-437875125  # pylint: disable=line-too-long
  and https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/  # pylint: disable=line-too-long

  Arguments:
    logits: The logits for the probabilities of the distribution. Sampling
      runs over the last axis, which is the axis `tf.nn.top_k` selects over.
    k: The number of samples to take.

  Returns:
    The indices of the values that were chosen.
  """
  # Sample from [tiny, 1) rather than [0, 1): a draw of exactly 0 would give
  # -inf Gumbel noise and silently exclude that entry from the sample.
  uniform = tf.random.uniform(
      tf.shape(logits), minval=np.finfo(np.float32).tiny, maxval=1.0)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform))
  _, indices = tf.nn.top_k(logits + gumbel_noise, k)
  return indices
225

226

227
@dataclasses.dataclass
class REALMSave:
  """Paths and metadata describing a saved REALM setup.

  NOTE(review): the field meanings below are inferred from their names only;
  confirm them against the code that creates and reads these objects.
  """
  # Path of the saved query embedder (presumably a model export).
  query_embedder_path: utils.PathType
  # Path of the text records (presumably the retrieval corpus).
  text_records: utils.PathType
  # Number of block records (presumably the size of the retrieval database).
  num_block_records: int
  # Free-form description of this save.
  description: str
233

234

235
# TODO(julesgm): This part needs work
236
# class InformationOnDevices:
237
#   """Information about the task to device distribution of the devices.
238
#
239
#   This doesn't make the assumption that each device has the same quantity of
240
#   tasks. Always true for TPUs. Maybe more resistant to weird configurations.
241
#   """
242
#
243
#   name_parse_pat = re.compile(
244
#       r"/job:worker/replica:0/task:([0-9]+)/device:(\w+):([0-9]+)"
245
#   )
246
#
247
#   def __init__(self):
248
#     self.devices_by_device_id = None
249
#     self.devices_by_task_id = None
250
#     self.num_tasks: int = 0
251
#     self.num_devices: int = 0
252
#     self.refresh()
253
#
254
#   def refresh(self) -> None:
255
#     """Refreshes the information.
256
#
257
#     Raises:
258
#       RuntimeError:
259
#         If one of the device names is in a format we can't parse.
260
#     """
261
#     devices_by_device_id = collections.defaultdict(list)
262
#     devices_by_task_id = collections.defaultdict(list)
263
#     for device in devices_to_use():
264
#       matches = self.name_parse_pat.match(device.name)
265
#       if matches is None:
266
#         raise RuntimeError(device.name)
267
#       task_no = int(matches.group(1))
268
#       device_no = int(matches.group(3))
269
#       devices_by_device_id[device_no].append((task_no, device))
270
#       devices_by_task_id[task_no].append((device_no, device))
271
#
272
#     LOGGER.debug("first devices_by_task_id:   %s", devices_by_task_id)
273
#     LOGGER.debug("first devices_by_device_id: %s", devices_by_device_id)
274
#
275
#     num_devices = len(devices_by_device_id)
276
#     num_tasks = len(devices_by_task_id)
277
#     LOGGER.debug("num_devices: %s", num_devices)
278
#     LOGGER.debug("num_tasks:   %s", num_tasks)
279
#
280
#     for k in devices_by_device_id:
281
#       devices_by_device_id[k].sort(key=lambda pair: pair[0])
282
#       # Remove the task no from the pair, as it is now equivalent to
283
#       # the position in the list.
284
#       devices_by_device_id[k] = [pair[1] for pair in devices_by_device_id[k]]
285
#
286
#     for k in devices_by_task_id:
287
#       devices_by_task_id[k].sort(key=lambda pair: pair[0])
288
#       # Remove the device no from the pair, as it is now equivalent to
289
#       # the position in the list.
290
#       devices_by_task_id[k] = [pair[1] for pair in devices_by_task_id[k]]
291
#
292
#     LOGGER.debug("second devices_by_task_id:   %s", devices_by_task_id)
293
#     LOGGER.debug("second devices_by_device_id: %s", devices_by_device_id)
294
#
295
#     self.devices_by_task_id = devices_by_task_id
296
#     self.devices_by_device_id = devices_by_device_id
297
#     self.num_devices = num_devices
298
#     self.num_tasks = num_tasks
299

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.