google-research

Форк
0
119 строк · 4.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tree structure helper classes and functions."""
17

18

19
from __future__ import absolute_import
20
from __future__ import division
21
from __future__ import print_function
22

23
from absl import flags
24
import numpy as np
25
import tensorflow.compat.v2 as tf
26

27
# Module-level handle to absl's flag container. None of the definitions
# visible in this file read it -- NOTE(review): confirm whether it is still
# needed by code outside this view.
FLAGS = flags.FLAGS
28

29

30
class Tree(object):
  """Tree class for the tree-based collaborative filtering model.

  Defines the basic tree operations: updating the tree structure, converting
  it to a ragged tensor and reporting per-level statistics.

  Attributes:
    children: dict of dicts. For each level l (where -1 is the root level)
      children[l] holds as keys the indices of the used nodes in this level,
      and their values are lists of their children node indices. Note that
      item indices are given a unique id by adding the number of total nodes
      to the item index as in tree_based/models/base
  """

  def __init__(self, tot_levels):
    """Creates an empty tree with one dict per level.

    Args:
      tot_levels: number of levels below the root level; the created level
        keys are -1 (root), 0, ..., tot_levels - 1.
    """
    self.children = {level: {} for level in range(-1, tot_levels)}

  def update(self, node_ind, parent, parent_level):
    """Adds a node to the tree, should be used bottom up.

    Args:
      node_ind: index of the node (or unique item id) being added.
      parent: index of the parent the node is attached to.
      parent_level: level of the parent; must be one of the level keys
        created by the constructor.

    Raises:
      KeyError: if parent_level is not a level of this tree.
    """
    self.children[parent_level].setdefault(parent, []).append(node_ind)

  def as_ragged(self, tot_nodes):
    """Converts the tree to a ragged tensor of children lists.

    Args:
      tot_nodes: total number of nodes; the result has tot_nodes + 1 rows.

    Returns:
      The result of tf.ragged.constant on a list whose row p holds the
      children of parent p. Parents without children get an empty row.
    """
    # One extra row on purpose: a parent index of -1 (the root level key used
    # by this class) addresses the last row through negative indexing.
    all_child = [[] for _ in range(tot_nodes + 1)]
    for level_children in self.children.values():
      for parent, kids in level_children.items():
        all_child[int(parent)] = [int(child) for child in kids]
    return tf.ragged.constant(all_child)

  def stats(self):
    """Calculates tree stats: num of used nodes, mean and std of node degrees.

    Returns:
      A tuple (used_nodes, mean_deg, std_deg) of dicts keyed by level.
      For a level without any used node the mean and std are NaN.
    """
    used_nodes = {}
    mean_deg = {}
    std_deg = {}
    for level, level_children in self.children.items():
      degs = [len(kids) for kids in level_children.values()]
      used_nodes[level] = len(degs)
      if degs:
        mean_deg[level] = np.mean(degs)
        std_deg[level] = np.std(degs)
      else:
        # np.mean/np.std of an empty list also give NaN, but they emit
        # RuntimeWarnings; return the same value without the noise.
        mean_deg[level] = np.nan
        std_deg[level] = np.nan
    return used_nodes, mean_deg, std_deg
78

79

80
def top_k_to_scores(top_k_rec, n_items):
  """Turns a ranked list of recommended items into a dense score vector.

  The item at rank i (0-based) in a list of length k gets the score k - i, so
  the first recommendation scores k and the last one scores 1. Items that are
  not in the list keep a score of 0. If an item index appears more than once,
  the score of its last occurrence wins.

  Args:
    top_k_rec: sequence of item indices ordered from best to worst; each
      entry is converted with int() before being used as an index.
    n_items: length of the returned score vector.

  Returns:
    np.array of size (n_items, ) holding the scores.

  Raises:
    IndexError: if an item index is out of bounds for n_items.
  """
  k = len(top_k_rec)
  scores = np.zeros(n_items)
  for rank, rec in enumerate(top_k_rec):
    scores[int(rec)] = k - rank
  return scores
86

87

88
def build_tree(closest_node_to_items, closest_node_to_nodes, nodes_per_level):
  """Builds the item-nodes tree based on closest nodes.

  Builds the tree bottom up. Skips nodes that are not connected to any item.

  Args:
      closest_node_to_items: np.array of size (n_item, ) where
        closest_node_to_items[item_index] = closest node index one level up.
      closest_node_to_nodes: np.array of size (tot_n_nodes, ) where
        closest_node_to_nodes[node_index] = closest node index one level up.
      nodes_per_level: list of the number of nodes per level excluding the
        root and the leaves.

  Returns:
      tree: Tree class.
  """
  # root index is -1
  tot_levels = len(nodes_per_level)
  # Computed once: it is both the offset of the leaf ids and the end of the
  # node index range of the deepest internal level.
  tot_nodes = sum(nodes_per_level)
  tree = Tree(tot_levels)
  # add leaves
  for item, node_parent in enumerate(closest_node_to_items):
    # unique leaf id = item index shifted past all internal node indices
    tree.update(
        tot_nodes + item, parent=node_parent, parent_level=tot_levels - 1)
  # add internal nodes, bottom-up
  last = tot_nodes
  for level in range(tot_levels - 1, -1, -1):
    # Nodes of this level occupy the index range [first_node, last).
    first_node = last - nodes_per_level[level]
    used_nodes = tree.children[level]
    for node in range(first_node, last):
      # Only nodes that already have children are attached to their parent,
      # so unused nodes never show up in the tree.
      if node in used_nodes:
        tree.update(
            node, parent=closest_node_to_nodes[node], parent_level=level - 1)
    last = first_node
  return tree
120

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.