google-research

Форк
0
119 строк · 3.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Provides `engine`: instance of `ComputeEngine` wraps TensorFlow functions."""
17
from typing import Optional
18

19
import numpy as np
20
import tensorflow as tf
21
from sparse_deferred.implicit import matrix
22

23
Tensor = tf.Tensor
24
Shape = tf.TensorShape|list[int]
25
DType = tf.DType|str
26

27

28
class _TFEngine(matrix.ComputeEngine):
29
  """Implements tensorflow as a `ComputeEngine`."""
30

31
  def where(self, condition, val_if_true,
32
            val_if_false):
33
    return tf.where(condition, val_if_true, val_if_false)
34

35
  def assert_equal(self, tensor1, tensor2):
36
    return tf.assert_equal(tensor1, tensor2)
37

38
  def assert_greater(self, tensor1, tensor2):
39
    return tf.assert_greater(tensor1, tensor2)
40

41
  def ones(self, sizes, dtype = 'float32'):
42
    return tf.ones(sizes, dtype=dtype)
43

44
  def abs(self, tensor):
45
    return tf.abs(tensor)
46

47
  def rsqrt(self, tensor):
48
    return tf.math.rsqrt(tensor)
49

50
  def ones_like(self, tensor):
51
    return tf.ones_like(tensor)
52

53
  def transpose(self, tensor):
54
    return tf.transpose(tensor)
55

56
  def einsum(self, notation, a, b):
57
    return tf.einsum(notation, a, b)
58

59
  def add_n(self, tensors):
60
    return tf.add_n(tensors)
61

62
  def shape(self, tensor):
63
    return tf.shape(tensor)
64

65
  def eye(self, num_rows, dtype='float32'):
66
    return tf.eye(num_rows, dtype=dtype)
67

68
  def cast(self, tensor, dtype = 'float32'):
69
    return tf.cast(tensor, dtype)
70

71
  def minimum(self, x, y):
72
    return tf.minimum(x, y)
73

74
  def argsort(self, tensor, axis=-1, direction='ASCENDING'):
75
    return tf.argsort(tensor, axis=axis, direction=direction)
76

77
  def all(self, tensor, axis = None,
78
          keepdims=False):
79
    return tf.reduce_all(tensor, axis=axis, keepdims=keepdims)
80

81
  def gather(self, tensor, indices, axis = 0):
82
    return tf.gather(tensor, indices, axis=axis)
83

84
  def unsorted_segment_sum(self, data, segment_ids,
85
                           num_segments):
86
    return tf.math.unsorted_segment_sum(data, segment_ids, num_segments)
87

88
  def concat(self, tensors, axis):
89
    return tf.concat(tensors, axis=axis)
90

91
  def zeros(self, shape, dtype = 'float32'):
92
    return tf.zeros(shape, dtype=dtype)
93

94
  def reshape(self, tensor, shape):
95
    return tf.reshape(tensor, shape)
96

97
  def boolean_mask(self, tensor, mask):
98
    return tf.boolean_mask(tensor, mask)
99

100
  def reduce_all(self, tensor, axis = None,
101
                 keepdims = False):
102
    return tf.reduce_all(tensor, axis=axis, keepdims=keepdims)
103

104
  def reduce_any(self, tensor, axis = None,
105
                 keepdims = False):
106
    return tf.reduce_any(tensor, axis=axis, keepdims=keepdims)
107

108
  def maximum(self, x, y):
109
    return tf.math.maximum(x, y)
110

111
  def range(self, up_to, dtype = 'float32'):
112
    return tf.range(up_to, dtype=dtype)
113

114
  def to_cpu(self, tensor):
115
    """Brings a tensor to the CPU, so that python can access it."""
116
    return np.array(tensor)
117

118

119
engine = _TFEngine()
120

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.