# google-research (119 lines, 3.4 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides `engine`: a `ComputeEngine` that wraps TensorFlow functions."""
from typing import Optional

import numpy as np
import tensorflow as tf
from sparse_deferred.implicit import matrix

# Type aliases for annotating tensor-valued arguments.
Tensor = tf.Tensor
# A static `tf.TensorShape` or a plain list of dimension sizes.
Shape = tf.TensorShape | list[int]
# A `tf.DType` object or its string name, such as 'float32'.
DType = tf.DType | str


class _TFEngine(matrix.ComputeEngine):
  """Implements TensorFlow as a `ComputeEngine`.

  Each method is a thin wrapper around a single TensorFlow op (or, for
  `to_cpu`, a NumPy conversion) and returns that op's result unchanged.
  """

  def where(
      self, condition: Tensor, val_if_true: Tensor, val_if_false: Tensor
  ) -> Tensor:
    """Picks `val_if_true` where `condition` is true, else `val_if_false`."""
    return tf.where(condition, val_if_true, val_if_false)

  def assert_equal(self, tensor1: Tensor, tensor2: Tensor):
    """Checks `tensor1 == tensor2` elementwise via `tf.assert_equal`.

    Returns whatever `tf.assert_equal` returns. When executing eagerly, a
    failed check raises immediately; in graph mode the check is an op.
    """
    return tf.assert_equal(tensor1, tensor2)

  def assert_greater(self, tensor1: Tensor, tensor2: Tensor):
    """Checks `tensor1 > tensor2` elementwise via `tf.assert_greater`."""
    return tf.assert_greater(tensor1, tensor2)

  def ones(self, sizes: Shape, dtype: DType = 'float32') -> Tensor:
    """Returns a tensor of shape `sizes` filled with ones."""
    return tf.ones(sizes, dtype=dtype)

  def abs(self, tensor: Tensor) -> Tensor:
    """Elementwise absolute value."""
    return tf.abs(tensor)

  def rsqrt(self, tensor: Tensor) -> Tensor:
    """Elementwise reciprocal square root, `1 / sqrt(tensor)`."""
    return tf.math.rsqrt(tensor)

  def ones_like(self, tensor: Tensor) -> Tensor:
    """Ones with the same shape and dtype as `tensor`."""
    return tf.ones_like(tensor)

  def transpose(self, tensor: Tensor) -> Tensor:
    """Reverses the dimension order (plain transpose for a matrix)."""
    return tf.transpose(tensor)

  def einsum(self, notation: str, a: Tensor, b: Tensor) -> Tensor:
    """Two-operand Einstein summation, e.g. `notation='ij,jk->ik'`."""
    return tf.einsum(notation, a, b)

  def add_n(self, tensors: list[Tensor]) -> Tensor:
    """Elementwise sum of same-shaped `tensors`."""
    return tf.add_n(tensors)

  def shape(self, tensor: Tensor) -> Tensor:
    """Dynamic shape of `tensor`, as a 1-D integer tensor."""
    return tf.shape(tensor)

  def eye(self, num_rows: int, dtype: DType = 'float32') -> Tensor:
    """Square identity matrix with `num_rows` rows."""
    return tf.eye(num_rows, dtype=dtype)

  def cast(self, tensor: Tensor, dtype: DType = 'float32') -> Tensor:
    """Converts `tensor` to `dtype`."""
    return tf.cast(tensor, dtype)

  def minimum(self, x: Tensor, y: Tensor) -> Tensor:
    """Elementwise minimum of `x` and `y` (broadcasting)."""
    return tf.minimum(x, y)

  def argsort(
      self, tensor: Tensor, axis: int = -1, direction: str = 'ASCENDING'
  ) -> Tensor:
    """Indices that sort `tensor` along `axis`.

    `direction` is passed through to `tf.argsort` ('ASCENDING' by default).
    """
    return tf.argsort(tensor, axis=axis, direction=direction)

  def all(
      self, tensor: Tensor, axis: Optional[int] = None, keepdims: bool = False
  ) -> Tensor:
    """Alias of `reduce_all`: logical AND over `axis` (all axes if None)."""
    # Delegate instead of duplicating the call so both names stay in sync.
    return self.reduce_all(tensor, axis=axis, keepdims=keepdims)

  def gather(self, tensor: Tensor, indices: Tensor, axis: int = 0) -> Tensor:
    """Takes the slices of `tensor` at `indices` along `axis`."""
    return tf.gather(tensor, indices, axis=axis)

  def unsorted_segment_sum(
      self, data: Tensor, segment_ids: Tensor, num_segments: int | Tensor
  ) -> Tensor:
    """Sums entries of `data` that share a segment id.

    For 1-D `segment_ids`, `output[i]` is the sum of `data[j]` over all `j`
    with `segment_ids[j] == i`. The ids need not be sorted, and the output
    has `num_segments` rows.
    """
    return tf.math.unsorted_segment_sum(data, segment_ids, num_segments)

  def concat(self, tensors: list[Tensor], axis: int) -> Tensor:
    """Concatenates `tensors` along `axis`."""
    return tf.concat(tensors, axis=axis)

  def zeros(self, shape: Shape, dtype: DType = 'float32') -> Tensor:
    """Returns a tensor of shape `shape` filled with zeros."""
    return tf.zeros(shape, dtype=dtype)

  def reshape(self, tensor: Tensor, shape: Shape) -> Tensor:
    """Reshapes `tensor` to `shape`."""
    return tf.reshape(tensor, shape)

  def boolean_mask(self, tensor: Tensor, mask: Tensor) -> Tensor:
    """Keeps the entries of `tensor` where `mask` is true."""
    return tf.boolean_mask(tensor, mask)

  def reduce_all(
      self, tensor: Tensor, axis: Optional[int] = None, keepdims: bool = False
  ) -> Tensor:
    """Logical AND over `axis` (over every axis when `axis` is None)."""
    return tf.reduce_all(tensor, axis=axis, keepdims=keepdims)

  def reduce_any(
      self, tensor: Tensor, axis: Optional[int] = None, keepdims: bool = False
  ) -> Tensor:
    """Logical OR over `axis` (over every axis when `axis` is None)."""
    return tf.reduce_any(tensor, axis=axis, keepdims=keepdims)

  def maximum(self, x: Tensor, y: Tensor) -> Tensor:
    """Elementwise maximum of `x` and `y` (broadcasting)."""
    # `tf.maximum` is the same op as `tf.math.maximum`; spelled like `minimum`.
    return tf.maximum(x, y)

  def range(self, up_to: int | Tensor, dtype: DType = 'float32') -> Tensor:
    """Returns `[0, 1, ..., up_to - 1]` as a 1-D tensor.

    Note that the default dtype is float32, not an integer type.
    """
    return tf.range(up_to, dtype=dtype)

  def to_cpu(self, tensor: Tensor) -> np.ndarray:
    """Brings a tensor to the CPU, so that python can access it.

    Returns a NumPy array built from `tensor`. This presumably requires a
    concrete (eager) tensor; symbolic graph tensors have no values yet.
    """
    return np.array(tensor)


# The TensorFlow-backed `ComputeEngine` instance that this module provides.
engine: matrix.ComputeEngine = _TFEngine()