google-research
298 lines · 9.6 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Various utilities that involve Tensorflow and accelerator devices.
17"""
18import logging19from typing import Dict, List, Optional, Tuple, Union # pylint: disable=g-import-not-at-top20
21import colorama22import dataclasses23import numpy as np24import tensorflow as tf25import tensorflow.python.distribute.values as values26import tensorflow.python.eager.context as context27import tensorflow.python.framework.ops as ops28import tensorflow.python.tpu.topology as topology29import utils30
31LOGGER = logging.getLogger(__name__)32
33
34@dataclasses.dataclass35class TpuConfigType:36resolver: tf.distribute.cluster_resolver.TPUClusterResolver37topology: topology.Topology38
39
40@dataclasses.dataclass41class DevicesMapType:42# pylint: disable=invalid-name43TPUs: List[context.LogicalDevice]44GPUs: List[context.LogicalDevice]45CPUs: List[context.LogicalDevice]46# pylint: enable=invalid-name47
48
49def init_tpus(tpu_name = None):50"""Initializes the connection with the TPUs."""51try:52if tpu_name:53resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_name)54else:55resolver = tf.distribute.cluster_resolver.TPUClusterResolver()56if resolver:57tf.config.experimental_connect_to_cluster(resolver)58topology_ = tf.tpu.experimental.initialize_tpu_system(resolver)59return TpuConfigType(resolver=resolver, topology=topology_)60except ValueError:61LOGGER.warning(62"%(red_bg)s%(white)s%(bold)s WORKING WITH CPUS %(reset)s",63dict(64red_bg=colorama.Back.RED,65white=colorama.Fore.WHITE,66bold=colorama.Style.BRIGHT,67reset=colorama.Style.RESET_ALL68)69)70return None71
72
73def devices_to_use():74"""Returns the device objects for the accel. we are the most likely to use.75
76Returns:
77List of logical devices of the accelerators we will use.
78"""
79if tf.config.list_logical_devices("TPU"):80devices = tf.config.list_logical_devices("TPU")81elif tf.config.list_logical_devices("GPU"):82devices = tf.config.list_logical_devices("GPU")83else:84devices = tf.config.list_logical_devices("CPU")85devices.sort()86return devices87
88
89def current_accelerator_type():90"""Returns the type of accelerator we are using.91
92devices_to_use guaranties that all the accelerators are of the same type.
93"""
94return devices_to_use()[0].device_type95
96
97def device_mapping():98"""Gives a dict with the different types of logical devices."""99return DevicesMapType(100TPUs=sorted(tf.config.list_logical_devices("TPU")),101GPUs=sorted(tf.config.list_logical_devices("GPU")),102CPUs=sorted(tf.config.list_logical_devices("CPU"))103)104
105
106def make_dict_distribute_fn(batch):107"""Builds the dict distribution function."""108def dict_distribute_fn(ctx):109"""Assumes all the tensors in the dict are of the same batch size."""110quanta = len(next(iter(batch.values()))) // ctx.num_replicas_in_sync111start = quanta * ctx.replica_id_in_sync_group112end = start + quanta113new = {}114for k, v in batch.items():115new[k] = v[start:end]116return new117return dict_distribute_fn118
119
120def deal_w_entry(strategy_outputs):121output = strategy_outputs.values # pytype: disable=attribute-error122if isinstance(strategy_outputs, tuple):123output = tf.concat(output, axis=0)124return output125
126
127def process_strat_output(128strategy_outputs,129name,130strategy,131current_batch_size,132):133"""Uniformizes the different outputs of strategy.run calls."""134if isinstance(strategy_outputs, values.PerReplica):135strategy_outputs: values.PerReplica136# LOGGER.debug("process_strat_output: %s: %s", name, str(strategy_outputs))137output = deal_w_entry(strategy_outputs)138utils.check_equal(output.shape, current_batch_size)139elif (isinstance(strategy_outputs, tuple) and140isinstance(strategy_outputs[0], values.PerReplica)):141strategy_outputs: Tuple[values.PerReplica, Ellipsis]142output = []143for indiv_val in strategy_outputs:144output.append(deal_w_entry(indiv_val))145output = tuple(output)146elif (isinstance(strategy_outputs, dict) and147isinstance(next(iter(strategy_outputs.values())), values.PerReplica)):148strategy_outputs: Dict[str, values.PerReplica]149output = {}150for k, indiv_val in strategy_outputs.items():151output[k] = deal_w_entry(indiv_val)152elif isinstance(153strategy_outputs,154ops.EagerTensor) or (isinstance(strategy_outputs, tuple) and155isinstance(strategy_outputs[0], ops.EagerTensor)):156output = strategy_outputs157else:158raise RuntimeError(159f"{name}: {type(strategy_outputs)}, {type(strategy)}"160)161
162return output163
164
165def load_reference_db(166checkpoint_path, variable_name167):168"""Load the reference database for retrieval.169
170This is mostly for compatibility with the REALM code.
171
172Args:
173checkpoint_path: The path of the checkpoint to use.
174variable_name: The variable name of the database inside of the checkpoint.
175
176Returns:
177A numpy array with the reference database.
178
179"""
180ckpt = tf.train.load_checkpoint(str(checkpoint_path))181try:182reference_db = ckpt.get_tensor(variable_name)183except tf.errors.NotFoundError:184reference_db = ckpt.get_tensor(185variable_name + "/.ATTRIBUTES/VARIABLE_VALUE")186
187return reference_db188
189
190def mips_exact_search(191vectors, num_neighbors, db192):193"""Does exact retrieval over a database.194
195Args:
196vectors: The key vectors to retrieve with.
197num_neighbors: The number of neighbors to extract.
198db: The vector datase to retrieve from.
199
200Returns:
201top_k: The top_k indices, the retrieved neighbors.
202inners: The inner products of each neighbor.
203"""
204product = tf.linalg.matmul(vectors, db, transpose_b=True)205inners, top_k = tf.math.top_k(product, k=num_neighbors, sorted=sorted)206return top_k, inners207
208
209def sample_without_replacement(logits, k):210"""Samples k values without replacement, from a set of logits.211
212Courtesy of https://github.com/tensorflow/tensorflow/issues/9260#issuecomment-437875125 # pylint: disable=line-too-long
213and https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/ # pylint: disable=line-too-long
214
215Arguments:
216logits: The logits for the probabilities of the distribution.
217k: The number of samples to take.
218
219Returns:
220The indices of the values that were chose.
221"""
222z = -tf.math.log(-tf.math.log(tf.random.uniform(tf.shape(logits), 0, 1)))223_, indices = tf.nn.top_k(logits + z, k)224return indices225
226
227@dataclasses.dataclass228class REALMSave:229query_embedder_path: utils.PathType230text_records: utils.PathType231num_block_records: int232description: str233
234
235# TODO(julesgm): This part needs work
236# class InformationOnDevices:
237# """Information about the task to device distribution of the devices.
238#
239# This doesn't make the assumption that each device has the same quantity of
240# tasks. Always true for TPUs. Maybe more resistant to weird configurations.
241# """
242#
243# name_parse_pat = re.compile(
244# r"/job:worker/replica:0/task:([0-9]+)/device:(\w+):([0-9]+)"
245# )
246#
247# def __init__(self):
248# self.devices_by_device_id = None
249# self.devices_by_task_id = None
250# self.num_tasks: int = 0
251# self.num_devices: int = 0
252# self.refresh()
253#
254# def refresh(self) -> None:
255# """Refreshes the information.
256#
257# Raises:
258# RuntimeError:
259# If one of the device names is in a format we can't parse.
260# """
261# devices_by_device_id = collections.defaultdict(list)
262# devices_by_task_id = collections.defaultdict(list)
263# for device in devices_to_use():
264# matches = self.name_parse_pat.match(device.name)
265# if matches is None:
266# raise RuntimeError(device.name)
267# task_no = int(matches.group(1))
268# device_no = int(matches.group(3))
269# devices_by_device_id[device_no].append((task_no, device))
270# devices_by_task_id[task_no].append((device_no, device))
271#
272# LOGGER.debug("first devices_by_task_id: %s", devices_by_task_id)
273# LOGGER.debug("first devices_by_device_id: %s", devices_by_device_id)
274#
275# num_devices = len(devices_by_device_id)
276# num_tasks = len(devices_by_task_id)
277# LOGGER.debug("num_devices: %s", num_devices)
278# LOGGER.debug("num_tasks: %s", num_tasks)
279#
280# for k in devices_by_device_id:
281# devices_by_device_id[k].sort(key=lambda pair: pair[0])
282# # Remove the task no from the pair, as it is now equivalent to
283# # the position in the list.
284# devices_by_device_id[k] = [pair[1] for pair in devices_by_device_id[k]]
285#
286# for k in devices_by_task_id:
287# devices_by_task_id[k].sort(key=lambda pair: pair[0])
288# # Remove the device no from the pair, as it is now equivalent to
289# # the position in the list.
290# devices_by_task_id[k] = [pair[1] for pair in devices_by_task_id[k]]
291#
292# LOGGER.debug("second devices_by_task_id: %s", devices_by_task_id)
293# LOGGER.debug("second devices_by_device_id: %s", devices_by_device_id)
294#
295# self.devices_by_task_id = devices_by_task_id
296# self.devices_by_device_id = devices_by_device_id
297# self.num_devices = num_devices
298# self.num_tasks = num_tasks
299