google-research

Форк
0
/
base_graph_test.py 
97 строк · 3.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Base test classes for graph API."""
17
from absl.testing import absltest
18

19
import jax
20
import jax.numpy as jnp
21

22
from jaxsel import agents
23
from jaxsel import graph_models
24
from jaxsel import subgraph_extractors
25
from jaxsel import synthetic_data
26

27

28
class BaseGraphTest(absltest.TestCase):
29
  """Tests basic functionnality of our Graph and Agent APIs."""
30

31
  def setUp(self):
32
    """Sets up an example image and associated graph and agent."""
33
    super().setUp()
34

35
    grid_size = 20
36
    num_paths = 2
37
    num_classes = 3
38

39
    agent_hidden_dim = 4
40
    max_graph_size = 75
41
    max_subgraph_size = 11
42
    num_steps_extractor = 50
43
    rho = 1e-6
44
    alpha = 1e-3
45
    ridge = 1e-7
46
    num_heads = 2
47
    n_encoder_layers = 2
48
    qkv_dim = 32
49
    mlp_dim = 64
50
    embedding_dim = 4
51
    hidden_dim = 16
52

53
    graph, start_node_id, label = synthetic_data.generate(
54
        grid_size, num_paths, num_classes)
55

56
    self.graph = graph
57
    self.start_node_id = start_node_id
58
    self.label = label
59

60
    agent_config = agents.AgentConfig(graph.graph_parameters(),
61
                                      agent_hidden_dim, agent_hidden_dim)
62

63
    extractor_config = subgraph_extractors.ExtractorConfig(
64
        max_graph_size, max_subgraph_size, rho, alpha, num_steps_extractor,
65
        ridge, agent_config)
66

67
    graph_classifier_config = graph_models.TransformerConfig(
68
        graph.graph_parameters(),
69
        num_heads=num_heads,
70
        num_layers=n_encoder_layers,
71
        qkv_dim=qkv_dim,
72
        mlp_dim=mlp_dim,
73
        image_size=grid_size**2,
74
        embedding_dim=embedding_dim,
75
        hidden_dim=hidden_dim,
76
        num_classes=num_classes)
77

78
    extractor = subgraph_extractors.SparseISTAExtractor(extractor_config)
79
    agent = agents.SimpleFiLMedAgentModel(agent_config)
80

81
    self.rng = jax.random.PRNGKey(2)
82
    self.agent = agent
83
    self.extractor = extractor
84
    self.graph_classifier = graph_models.TransformerClassifier(
85
        graph_classifier_config)
86

87
  def test_random_walk_on_graph(self):
88
    """Tests ability to perform a random walk on a graph built from an image."""
89

90
    rng_agent, self.rng = jax.random.split(self.rng)
91
    self.agent.init_with_output(rng_agent, self.graph, method=self.agent.walk)
92

93
  def test_out_of_bounds_pixel_neighbors(self):
94
    """The out of bounds pixel should only be linked to the start node."""
95
    relation_ids, neighbor_node_ids = self.graph.outgoing_edges(-1)
96
    del relation_ids
97
    assert jnp.all(neighbor_node_ids == self.graph._start_node_id)
98

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.