google-research
260 строк · 9.2 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Tests converting TFGNN GraphTensor to GraphStruct with {tf, jax, np} engines.
"""
18
19import jax
20import tensorflow as tf
21import tensorflow_gnn
22
23from sparse_deferred import tfgnn as sdtfgnn
24import sparse_deferred.jax as sdjax
25import sparse_deferred.np as sdnp
26
27
# Student node names. A name's position in the list is its node index.
_STUDENTS = [b'Sami', b'Bryan', b'Jonathan']
_STUDENT_IDX = {name: index for index, name in enumerate(_STUDENTS)}

# Course nodes as (name, numeric code) pairs, indexed by list position.
_COURSES = [
    (b'Arabic', 101),
    (b'ML', 102),
    (b'English', 103),
    (b'Calculus', 104),
]
_COURSE_IDX = {name: index for index, (name, _) in enumerate(_COURSES)}

# Topic node names, indexed by list position.
_TOPICS = [b'NaturalLanguage', b'Math']
_TOPIC_IDX = {name: index for index, name in enumerate(_TOPICS)}

# `has_topic` edges as (course name, topic name) pairs.
_COURSE_TOPICS = [
    (b'Arabic', b'NaturalLanguage'),
    (b'English', b'NaturalLanguage'),
    (b'ML', b'Math'),
    (b'Calculus', b'Math'),
]

# `enrollments` edges: (student name, course name) -> grade.
_ENROLL_GRADES = {
    (b'Sami', b'Arabic'): 95,
    (b'Bryan', b'English'): 100,
    (b'Sami', b'ML'): 97,
    (b'Bryan', b'Calculus'): 99,
    (b'Bryan', b'ML'): 100,
    (b'Jonathan', b'Calculus'): 100,
    (b'Jonathan', b'Arabic'): 100,
}
65
66
def _make_tf_example():
  """Builds a `tf.train.Example` holding the students/courses/topics graph.

  Node sets `students`, `courses` and `topics` are filled from the module-level
  `_STUDENTS`, `_COURSES` and `_TOPICS`. Edge set `has_topic` is filled from
  `_COURSE_TOPICS`, and edge set `enrollments` (with a `grade` feature) from
  `_ENROLL_GRADES`. Edge endpoints are written as positions within the
  corresponding node lists.

  Returns:
    The populated `tf.train.Example`.
  """
  example = tf.train.Example()
  feature_map = example.features.feature

  def int64_values(key):
    """Returns the repeated int64 field stored under feature `key`."""
    return feature_map[key].int64_list.value

  def bytes_values(key):
    """Returns the repeated bytes field stored under feature `key`."""
    return feature_map[key].bytes_list.value

  #### NODES

  # Student nodes.
  bytes_values('nodes/students.name').extend(_STUDENTS)
  int64_values('nodes/students.#size').append(len(_STUDENTS))

  # Course nodes: names and codes are stored as two parallel features.
  course_names = [name for name, _ in _COURSES]
  course_codes = [code for _, code in _COURSES]
  bytes_values('nodes/courses.name').extend(course_names)
  int64_values('nodes/courses.code').extend(course_codes)
  int64_values('nodes/courses.#size').append(len(course_names))

  # Topic nodes.
  bytes_values('nodes/topics.name').extend(_TOPICS)
  int64_values('nodes/topics.#size').append(len(_TOPICS))

  #### EDGES

  # course -> topic edges.
  int64_values('edges/has_topic.#source').extend(
      [_COURSE_IDX[course] for course, _ in _COURSE_TOPICS])
  int64_values('edges/has_topic.#target').extend(
      [_TOPIC_IDX[topic] for _, topic in _COURSE_TOPICS])
  int64_values('edges/has_topic.#size').append(len(_COURSE_TOPICS))

  # student -> course edges. Iterating the dict keeps endpoints and grades in
  # the same (insertion) order.
  int64_values('edges/enrollments.#source').extend(
      [_STUDENT_IDX[student] for student, _ in _ENROLL_GRADES])
  int64_values('edges/enrollments.#target').extend(
      [_COURSE_IDX[course] for _, course in _ENROLL_GRADES])
  feature_map['edges/enrollments.grade'].float_list.value.extend(
      list(_ENROLL_GRADES.values()))
  int64_values('edges/enrollments.#size').append(len(_ENROLL_GRADES))

  #### CONTEXT
  int64_values('context/root_node').append(1)

  return example
117
118
def _make_schema():
  """Builds the `GraphSchema` describing the students/courses/topics graph.

  Returns:
    A `tensorflow_gnn.GraphSchema` with node sets `students`, `courses` and
    `topics`, and edge sets `has_topic` (courses -> topics) and `enrollments`
    (students -> courses, carrying a float32 `grade` feature).
  """
  schema = tensorflow_gnn.GraphSchema()
  node_sets = schema.node_sets
  edge_sets = schema.edge_sets

  # Every node set carries a string `name` feature.
  for set_name in ('students', 'courses', 'topics'):
    node_sets[set_name].features['name'].dtype = tf.string.as_datatype_enum

  # Courses additionally carry an int32 `code` feature with one dimension of
  # size 1.
  course_code = node_sets['courses'].features['code']
  course_code.dtype = tf.int32.as_datatype_enum
  course_code.shape.dim.add().size = 1

  has_topic = edge_sets['has_topic']
  has_topic.source = 'courses'
  has_topic.target = 'topics'

  enrollments = edge_sets['enrollments']
  enrollments.source = 'students'
  enrollments.target = 'courses'
  enrollments.features['grade'].dtype = tf.float32.as_datatype_enum

  return schema
139
140
class _BaseIOTest(tf.test.TestCase):
  """Shared assertions for GraphStruct conversion tests."""

  def _assert_correct(self, graph):
    """Asserts that `graph` matches the module-level example graph data.

    Args:
      graph: Converted graph exposing `nodes` (node set name -> feature dict)
        and `edges` (edge set name -> ((source, target), feature dict)).
    """
    course_names = [name for name, _ in _COURSES]
    course_codes = [code for _, code in _COURSES]

    # Node features.
    nodes = graph.nodes
    self.assertAllEqual(nodes['students']['name'], _STUDENTS)
    self.assertAllEqual(nodes['courses']['name'], course_names)
    # The schema built by `_make_schema` gives `code` an extra dimension of
    # size 1, so the expected value needs a matching trailing axis.
    expected_codes = tf.expand_dims(course_codes, -1)
    self.assertAllEqual(nodes['courses']['code'], expected_codes)
    self.assertAllEqual(nodes['topics']['name'], _TOPICS)

    # `has_topic` edges: compared as a set of (course, topic) name pairs, and
    # expected to carry no features.
    endpoints, edge_features = graph.edges['has_topic']
    course_ids, topic_ids = endpoints
    self.assertEmpty(edge_features)
    edge_courses = tf.gather(course_names, course_ids).numpy()
    edge_topics = tf.gather(_TOPICS, topic_ids).numpy()
    self.assertSetEqual(
        set(zip(edge_courses, edge_topics)), set(_COURSE_TOPICS))

    # `enrollments` edges: compared as a set of (student, course) name pairs.
    endpoints, edge_features = graph.edges['enrollments']
    student_ids, course_ids = endpoints
    edge_students = tf.gather(_STUDENTS, student_ids).numpy()
    edge_courses = tf.gather(course_names, course_ids).numpy()
    self.assertSetEqual(
        set(zip(edge_students, edge_courses)), set(_ENROLL_GRADES.keys()))
    self.assertIn('grade', edge_features)

    # Each edge's grade must be the one recorded for its (student, course).
    edge_grades = edge_features['grade']
    for student, course, grade in zip(
        edge_students, edge_courses, edge_grades):
      self.assertAllEqual(grade, _ENROLL_GRADES[(student, course)])
171
172
class TensorflowIOTest(_BaseIOTest):
  """Conversion tests that call the converters without an `engine` argument.

  NOTE(review): presumably this exercises the TensorFlow engine, per the class
  name -- confirm against the converters' default.
  """

  def test_graph_struct_from_tf_example(self):
    """Converts a `tf.train.Example` plus schema directly to a GraphStruct."""
    graph_struct = sdtfgnn.graph_struct_from_tf_example(
        _make_tf_example(), _make_schema())
    self._assert_correct(graph_struct)

  def test_graph_struct_from_tfgnn_graph_tensor(self):
    """Parses the example into a GraphTensor, then converts that."""
    serialized_example = _make_tf_example().SerializeToString()
    graph_spec = tensorflow_gnn.create_graph_spec_from_schema_pb(
        _make_schema())
    graph_tensor = tensorflow_gnn.parse_single_example(
        graph_spec, serialized_example)
    self._assert_correct(
        sdtfgnn.graph_struct_from_graph_tensor(graph_tensor))
189
190
class NumpyIOTest(_BaseIOTest):
  """Conversion tests that pass `sdnp.engine` to the converters."""

  def test_graph_struct_from_tf_example(self):
    """Converts a `tf.train.Example` plus schema directly to a GraphStruct."""
    graph_struct = sdtfgnn.graph_struct_from_tf_example(
        _make_tf_example(), _make_schema(), engine=sdnp.engine)
    self._assert_correct(graph_struct)

  def test_graph_struct_from_tfgnn_graph_tensor(self):
    """Parses the example into a GraphTensor, then converts that."""
    serialized_example = _make_tf_example().SerializeToString()
    graph_spec = tensorflow_gnn.create_graph_spec_from_schema_pb(
        _make_schema())
    graph_tensor = tensorflow_gnn.parse_single_example(
        graph_spec, serialized_example)
    self._assert_correct(
        sdtfgnn.graph_struct_from_graph_tensor(
            graph_tensor, engine=sdnp.engine))
209
210
class JaxIOTest(tf.test.TestCase):
  """Conversion test that passes `sdjax.engine`, on a small numeric-only graph.

  Jax does not support string features, so this test builds its own simple
  graph instead of reusing the string-valued example graph.

  https://github.com/google/jax/issues/2084
  """

  def test_graph_struct_from_graph_tensor(self):
    """Converts a hand-built GraphTensor and checks values and array types."""
    context = tensorflow_gnn.Context.from_fields(
        features={'x': tf.constant([1])})
    node_sets = {
        'n1': tensorflow_gnn.NodeSet.from_fields(
            features={'f1': tf.constant([2, 3, 4])},
            sizes=tf.constant([3])),
        'n2': tensorflow_gnn.NodeSet.from_fields(
            features={'f2': tf.constant([[-1.0], [-2.0]])},
            sizes=tf.constant([2])),
    }
    edge_sets = {
        'e': tensorflow_gnn.EdgeSet.from_fields(
            adjacency=tensorflow_gnn.Adjacency.from_indices(
                source=('n1', tf.constant([0, 0, 1, 2])),
                target=('n2', tf.constant([0, 1, 0, 1]))),
            features={'f': tf.constant([1, 2, 3, 4])},
            sizes=tf.constant([4])),
    }
    graph_tensor = tensorflow_gnn.GraphTensor.from_pieces(
        context=context, node_sets=node_sets, edge_sets=edge_sets)

    graph_struct = sdtfgnn.graph_struct_from_graph_tensor(
        graph_tensor, engine=sdjax.engine)

    # Values survive the conversion.
    self.assertAllEqual(graph_struct.schema['e'], ('n1', 'n2'))
    self.assertAllEqual(graph_struct.nodes['n1']['f1'], [2, 3, 4])
    self.assertAllEqual(graph_struct.nodes['n2']['f2'], [[-1.0], [-2.0]])
    (sources, targets), edge_features = graph_struct.edges['e']
    self.assertAllEqual(edge_features['f'], [1, 2, 3, 4])
    self.assertAllEqual(sources, [0, 0, 1, 2])
    self.assertAllEqual(targets, [0, 1, 0, 1])

    # Endpoints and features are all Jax arrays.
    for array in (sources, targets, edge_features['f']):
      self.assertIsInstance(array, jax.Array)

    node_feature_values = [
        value
        for feature_dict in graph_struct.nodes.values()
        for value in feature_dict.values()
    ]
    for value in node_feature_values:
      self.assertIsInstance(value, jax.Array)
    # Two node features plus one graph-level feature.
    self.assertEqual(len(node_feature_values), 3)
257
258
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
261