google-research
119 строк · 4.1 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Tree structure helper classes and functions."""
17
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23from absl import flags
24import numpy as np
25import tensorflow.compat.v2 as tf
26
27FLAGS = flags.FLAGS
28
29
class Tree(object):
  """Tree class for tree based collaborative filtering model.

  Module to define basic tree operations, such as updating the tree structure,
  providing tree stats and more.

  Attributes:
    children: dict of dicts. For each level l (where -1 is the root level)
      children[l] holds as keys the indices of the used nodes in this level,
      and their values are lists of their children node indices. Note that
      item indices are given a unique id by adding the number of total nodes
      to the item index as in tree_based/models/base
  """

  def __init__(self, tot_levels):
    """Creates an empty tree.

    Args:
      tot_levels: number of levels below the root. One initially empty dict
        is created for every level from -1 (the root level) through
        tot_levels - 1, in that order.
    """
    self.children = {level: {} for level in range(-1, tot_levels)}

  def update(self, node_ind, parent, parent_level):
    """Adds a node to the tree, should be used bottom up.

    Args:
      node_ind: index (or unique item id) of the node to add.
      parent: index of the node's parent.
      parent_level: level of the parent. It must be one of the levels created
        in __init__, otherwise a KeyError is raised.
    """
    self.children[parent_level].setdefault(parent, []).append(node_ind)

  def as_ragged(self, tot_nodes):
    """Converts the tree to a ragged tensor of children indices.

    Args:
      tot_nodes: size of the node index space. Parent indices should lie in
        [0, tot_nodes) so they do not collide with the root's row.

    Returns:
      A tf.RaggedTensor with tot_nodes + 1 rows, where row i lists the
      children of node i (as ints). The root (index -1) ends up in the last
      row (index tot_nodes) through negative indexing. Nodes without children
      get an empty row.
    """
    all_child = [[] for _ in range(tot_nodes + 1)]
    for level_children in self.children.values():
      for parent, kids in level_children.items():
        # int(-1) for the root indexes the extra last row.
        all_child[int(parent)] = [int(child) for child in kids]
    return tf.ragged.constant(all_child)

  def stats(self):
    """Calculates per-level tree stats.

    Returns:
      A tuple (used_nodes, mean_deg, std_deg) of dicts keyed by level:
      - used_nodes: the number of nodes at that level that have children.
      - mean_deg, std_deg: the mean and standard deviation of those nodes'
        degrees (number of children).
      Levels without any used node get NaN for mean and std.
    """
    used_nodes = {}
    mean_deg = {}
    std_deg = {}
    for level, level_children in self.children.items():
      degs = [len(kids) for kids in level_children.values()]
      used_nodes[level] = len(degs)
      if degs:
        mean_deg[level] = np.mean(degs)
        std_deg[level] = np.std(degs)
      else:
        # np.mean / np.std of an empty list also give NaN, but they emit
        # "Mean of empty slice" RuntimeWarnings; return NaN explicitly.
        mean_deg[level] = np.float64(np.nan)
        std_deg[level] = np.float64(np.nan)
    return used_nodes, mean_deg, std_deg
78
79
def top_k_to_scores(top_k_rec, n_items):
  """Turns a ranked top-k recommendation list into a dense score vector.

  Args:
    top_k_rec: sequence of item indices ordered from first to last rank.
      Each entry is converted with int() before indexing.
    n_items: length of the returned score vector.

  Returns:
    np.ndarray of n_items float zeros, except at the recommended items:
    - the first entry gets score k = len(top_k_rec), the second k - 1, and
      so on, down to 1 for the last entry;
    - if an item appears more than once, its later (lower) score overwrites
      the earlier one.
  """
  scores = np.zeros(n_items)
  score = len(top_k_rec)
  for item in top_k_rec:
    scores[int(item)] = score
    score -= 1
  return scores
86
87
def build_tree(closest_node_to_items, closest_node_to_nodes, nodes_per_level):
  """Builds the item-nodes tree based on closest nodes.

  Builds the tree bottom up. Skips nodes that are not connected to any item.

  Args:
    closest_node_to_items: np.array of size (n_item, ) where
      closest_node_to_items[item_index] = closest node index one level up.
    closest_node_to_nodes: np.array of size (tot_n_nodes, ) where
      closest_node_to_nodes[node_index] = closest node index one level up.
      Entries for level-0 nodes are used as their parent in the root level
      (-1); presumably they hold the root index -1 -- TODO confirm with
      callers.
    nodes_per_level: list of the number of nodes per level excluding the
      root and the leaves. Node indices are laid out consecutively level by
      level, starting with level 0.

  Returns:
    tree: Tree class. Items are stored under the unique id
      sum(nodes_per_level) + item_index.
  """
  # root index is -1
  tot_levels = len(nodes_per_level)
  # Hoisted out of the item loop: it is the same for every item.
  tot_nodes = sum(nodes_per_level)
  tree = Tree(tot_levels)
  # Add leaves (items) under their closest node in the bottom level.
  for item, node_parent in enumerate(closest_node_to_items):
    tree.update(tot_nodes + item, parent=node_parent,
                parent_level=tot_levels - 1)
  # Add internal nodes, bottom-up. Level `level` owns the index range
  # [first_node, last_node). A node is attached to its parent only if it
  # already has children, which skips nodes not connected to any item.
  for level in range(tot_levels - 1, -1, -1):
    first_node = sum(nodes_per_level[:level])
    last_node = first_node + nodes_per_level[level]
    level_children = tree.children[level]
    for node in range(first_node, last_node):
      if node in level_children:
        tree.update(node, parent=closest_node_to_nodes[node],
                    parent_level=level - 1)
  return tree
120