google-research

Форк
0
260 строк · 9.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests converting TFGNN GraphTensor to GraphStruct with {tf, jax, np} engines.
17
"""
18

19
import jax
20
import tensorflow as tf
21
import tensorflow_gnn
22

23
from sparse_deferred import tfgnn as sdtfgnn
24
import sparse_deferred.jax as sdjax
25
import sparse_deferred.np as sdnp
26

27

28
# Toy graph shared by the tests below. For every node list, the position of an
# entry is its node index; the `*_IDX` dicts map a node name to that index.

# Student node names.
_STUDENTS = [b'Sami', b'Bryan', b'Jonathan']
_STUDENT_IDX = dict(zip(_STUDENTS, range(len(_STUDENTS))))

# Course nodes as (name, code) pairs; only the name is used as lookup key.
_COURSES = [
    (b'Arabic', 101), (b'ML', 102), (b'English', 103), (b'Calculus', 104),
]
_COURSE_IDX = {name: idx for idx, (name, _) in enumerate(_COURSES)}

# Topic node names.
_TOPICS = [b'NaturalLanguage', b'Math']
_TOPIC_IDX = dict(zip(_TOPICS, range(len(_TOPICS))))

# `has_topic` edges as (course name, topic name) pairs.
_COURSE_TOPICS = [
    (b'Arabic', b'NaturalLanguage'), (b'English', b'NaturalLanguage'),
    (b'ML', b'Math'), (b'Calculus', b'Math'),
]

# `enrollments` edges: (student name, course name) -> grade. Insertion order is
# the order in which the edges are written into the tf.Example.
_ENROLL_GRADES = {
    (b'Sami', b'Arabic'): 95,
    (b'Bryan', b'English'): 100,
    (b'Sami', b'ML'): 97,
    (b'Bryan', b'Calculus'): 99,
    (b'Bryan', b'ML'): 100,
    (b'Jonathan', b'Calculus'): 100,
    (b'Jonathan', b'Arabic'): 100,
}
65

66

67
def _make_tf_example():
  """Encodes the module-level toy graph as a `tf.train.Example`.

  Feature keys written here follow the patterns `nodes/<set>.<feature>`,
  `edges/<set>.<feature>` and `context/<feature>`, with the special features
  `#size` (both kinds of set) and `#source` / `#target` (edge sets).

  Returns:
    A `tf.train.Example` with node sets `students`, `courses` and `topics`,
    edge sets `has_topic` and `enrollments`, and the context feature
    `root_node`.
  """
  example = tf.train.Example()
  feature_map = example.features.feature

  # ---- Node sets.
  feature_map['nodes/students.name'].bytes_list.value.extend(_STUDENTS)
  feature_map['nodes/students.#size'].int64_list.value.append(len(_STUDENTS))

  names = [name for name, _ in _COURSES]
  codes = [code for _, code in _COURSES]
  feature_map['nodes/courses.name'].bytes_list.value.extend(names)
  feature_map['nodes/courses.code'].int64_list.value.extend(codes)
  feature_map['nodes/courses.#size'].int64_list.value.append(len(names))

  feature_map['nodes/topics.name'].bytes_list.value.extend(_TOPICS)
  feature_map['nodes/topics.#size'].int64_list.value.append(len(_TOPICS))

  # ---- Edge set `has_topic`: course index -> topic index.
  topic_sources = [_COURSE_IDX[course] for course, _ in _COURSE_TOPICS]
  topic_targets = [_TOPIC_IDX[topic] for _, topic in _COURSE_TOPICS]
  feature_map['edges/has_topic.#source'].int64_list.value.extend(topic_sources)
  feature_map['edges/has_topic.#target'].int64_list.value.extend(topic_targets)
  feature_map['edges/has_topic.#size'].int64_list.value.append(
      len(_COURSE_TOPICS))

  # ---- Edge set `enrollments`: student index -> course index, plus grade.
  # One pass keeps endpoints and grades aligned in dict insertion order.
  enroll_sources = []
  enroll_targets = []
  enroll_grades = []
  for (student, course), grade in _ENROLL_GRADES.items():
    enroll_sources.append(_STUDENT_IDX[student])
    enroll_targets.append(_COURSE_IDX[course])
    enroll_grades.append(grade)
  feature_map['edges/enrollments.#source'].int64_list.value.extend(
      enroll_sources)
  feature_map['edges/enrollments.#target'].int64_list.value.extend(
      enroll_targets)
  feature_map['edges/enrollments.grade'].float_list.value.extend(
      enroll_grades)
  feature_map['edges/enrollments.#size'].int64_list.value.append(
      len(enroll_targets))

  # ---- Context.
  feature_map['context/root_node'].int64_list.value.append(1)

  return example
117

118

119
def _make_schema():
  """Builds the `tensorflow_gnn.GraphSchema` for the toy graph.

  Declares string `name` features on all three node sets, an int32 `code`
  feature of shape [1] on `courses`, and a float32 `grade` feature on the
  `enrollments` edge set. No context features are declared.

  Returns:
    The populated `tensorflow_gnn.GraphSchema`.
  """
  schema = tensorflow_gnn.GraphSchema()
  string_enum = tf.string.as_datatype_enum

  # Node sets.
  schema.node_sets['students'].features['name'].dtype = string_enum

  courses = schema.node_sets['courses']
  courses.features['name'].dtype = string_enum
  course_code = courses.features['code']
  course_code.dtype = tf.int32.as_datatype_enum
  # Each course code is declared as a length-1 vector rather than a scalar.
  course_code.shape.dim.add().size = 1

  schema.node_sets['topics'].features['name'].dtype = string_enum

  # Edge sets.
  has_topic = schema.edge_sets['has_topic']
  has_topic.source = 'courses'
  has_topic.target = 'topics'

  enrollments = schema.edge_sets['enrollments']
  enrollments.source = 'students'
  enrollments.target = 'courses'
  enrollments.features['grade'].dtype = tf.float32.as_datatype_enum

  return schema
139

140

141
class _BaseIOTest(tf.test.TestCase):
  """Shared assertions for tests that convert the toy graph.

  Defines no test methods itself; subclasses build a graph with a particular
  engine and hand it to `_assert_correct`.
  """

  def _gather_names(self, names, indices):
    """Returns `names` looked up at `indices`, as a numpy array."""
    return tf.gather(names, indices).numpy()

  def _assert_correct(self, graph):
    """Asserts that `graph` holds exactly the module-level toy graph.

    Args:
      graph: Object exposing `nodes[set_name][feature_name]` and
        `edges[set_name] == ((source, target), features)`.
    """
    names_of_courses = tuple(name for name, _ in _COURSES)
    codes_of_courses = tuple(code for _, code in _COURSES)

    # Node features.
    self.assertAllEqual(graph.nodes['students']['name'], _STUDENTS)
    self.assertAllEqual(graph.nodes['courses']['name'], names_of_courses)
    # The schema declares `code` with a trailing dimension of size 1, so the
    # expected values need the same extra axis.
    expected_codes = tf.expand_dims(codes_of_courses, -1)
    self.assertAllEqual(graph.nodes['courses']['code'], expected_codes)
    self.assertAllEqual(graph.nodes['topics']['name'], _TOPICS)

    # `has_topic` edges: compared as a set of (course, topic) name pairs, so
    # edge order does not matter. This edge set carries no features.
    (sources, targets), edge_features = graph.edges['has_topic']
    self.assertEmpty(edge_features)
    actual_topic_pairs = set(zip(
        self._gather_names(names_of_courses, sources),
        self._gather_names(_TOPICS, targets)))
    self.assertSetEqual(actual_topic_pairs, set(_COURSE_TOPICS))

    # `enrollments` edges: same set comparison on (student, course) pairs.
    (sources, targets), edge_features = graph.edges['enrollments']
    students_per_edge = self._gather_names(_STUDENTS, sources)
    courses_per_edge = self._gather_names(names_of_courses, targets)
    actual_enrollments = set(zip(students_per_edge, courses_per_edge))
    self.assertSetEqual(actual_enrollments, set(_ENROLL_GRADES.keys()))
    self.assertIn('grade', edge_features)

    # Each edge's grade must match the grade recorded for its endpoints.
    edge_triples = zip(
        students_per_edge, courses_per_edge, edge_features['grade'])
    for student_name, course_name, actual_grade in edge_triples:
      expected_grade = _ENROLL_GRADES[(student_name, course_name)]
      self.assertAllEqual(actual_grade, expected_grade)
171

172

173
class TensorflowIOTest(_BaseIOTest):
  """Converts the toy graph without passing an explicit `engine`.

  NOTE(review): presumably the default engine is TensorFlow, per the class
  name -- confirm against `sparse_deferred.tfgnn`.
  """

  def test_graph_struct_from_tf_example(self):
    """Converts straight from the `tf.train.Example` plus its schema."""
    converted = sdtfgnn.graph_struct_from_tf_example(
        _make_tf_example(), _make_schema())
    self._assert_correct(converted)

  def test_graph_struct_from_tfgnn_graph_tensor(self):
    """Parses the example into a GraphTensor first, then converts that."""
    serialized = _make_tf_example().SerializeToString()
    spec = tensorflow_gnn.create_graph_spec_from_schema_pb(_make_schema())
    parsed = tensorflow_gnn.parse_single_example(spec, serialized)
    converted = sdtfgnn.graph_struct_from_graph_tensor(parsed)
    self._assert_correct(converted)
189

190

191
class NumpyIOTest(_BaseIOTest):
  """Converts the toy graph with `engine=sdnp.engine`."""

  def test_graph_struct_from_tf_example(self):
    """Converts straight from the `tf.train.Example` plus its schema."""
    converted = sdtfgnn.graph_struct_from_tf_example(
        _make_tf_example(), _make_schema(), engine=sdnp.engine)
    self._assert_correct(converted)

  def test_graph_struct_from_tfgnn_graph_tensor(self):
    """Parses the example into a GraphTensor first, then converts that."""
    serialized = _make_tf_example().SerializeToString()
    spec = tensorflow_gnn.create_graph_spec_from_schema_pb(_make_schema())
    parsed = tensorflow_gnn.parse_single_example(spec, serialized)
    converted = sdtfgnn.graph_struct_from_graph_tensor(
        parsed, engine=sdnp.engine)
    self._assert_correct(converted)
209

210

211
class JaxIOTest(tf.test.TestCase):
  """Converts a small numeric-only graph with `engine=sdjax.engine`.

  Jax does not support string features, so this test builds its own (simple)
  graph instead of reusing the students/courses one.

  https://github.com/google/jax/issues/2084
  """

  def test_graph_struct_from_graph_tensor(self):
    """Checks schema, values and array types after GraphTensor conversion."""
    # Assemble the GraphTensor piece by piece.
    context = tensorflow_gnn.Context.from_fields(
        features={'x': tf.constant([1])})
    node_sets = {
        'n1': tensorflow_gnn.NodeSet.from_fields(
            features={'f1': tf.constant([2, 3, 4])},
            sizes=tf.constant([3])),
        'n2': tensorflow_gnn.NodeSet.from_fields(
            features={'f2': tf.constant([[-1.0], [-2.0]])},
            sizes=tf.constant([2])),
    }
    adjacency = tensorflow_gnn.Adjacency.from_indices(
        source=('n1', tf.constant([0, 0, 1, 2])),
        target=('n2', tf.constant([0, 1, 0, 1])))
    edge_sets = {
        'e': tensorflow_gnn.EdgeSet.from_fields(
            adjacency=adjacency,
            features={'f': tf.constant([1, 2, 3, 4])},
            sizes=tf.constant([4])),
    }
    graph_tensor = tensorflow_gnn.GraphTensor.from_pieces(
        context=context, node_sets=node_sets, edge_sets=edge_sets)

    converted = sdtfgnn.graph_struct_from_graph_tensor(
        graph_tensor, engine=sdjax.engine)

    # Values must survive the conversion unchanged.
    self.assertAllEqual(converted.schema['e'], ('n1', 'n2'))
    self.assertAllEqual(converted.nodes['n1']['f1'], [2, 3, 4])
    self.assertAllEqual(converted.nodes['n2']['f2'], [[-1.0], [-2.0]])
    endpoints, features_of_e = converted.edges['e']
    sources = endpoints[0]
    targets = endpoints[1]
    self.assertAllEqual(features_of_e['f'], [1, 2, 3, 4])
    self.assertAllEqual(sources, [0, 0, 1, 2])
    self.assertAllEqual(targets, [0, 1, 0, 1])

    # Everything must come back as a jax array.
    self.assertIsInstance(sources, jax.Array)
    self.assertIsInstance(targets, jax.Array)
    self.assertIsInstance(features_of_e['f'], jax.Array)
    node_feature_values = [
        value
        for feature_dict in converted.nodes.values()
        for value in feature_dict.values()
    ]
    for value in node_feature_values:
      self.assertIsInstance(value, jax.Array)
    # Expected count: two node features plus one graph-level feature
    # (presumably context feature `x` is exposed through `nodes` -- confirm).
    self.assertEqual(len(node_feature_values), 3)
257

258

259
# Run every test case in this module when it is executed as a script.
if __name__ == '__main__':
  tf.test.main()
261

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.