5
// Hierarchical Softmax protobuf convention:
6
// The HSM operator requires a hierarchy of vocabulary words in the form of a
7
// tree from the user. This tree is expressed using the proto format.
8
// TreeProto points to the root NodeProto which can recursively contain children
9
// NodeProtos (internal nodes) or word_ids (leaf nodes).
11
// The aforementioned TreeProto is internally translated into a list of word_ids
12
// tagged with a list of NodeProtos that lie in the path from the root to that
13
// word_id using hsm_util.create_hierarchy(tree_proto).
14
// Specifically, HierarchyProto contains a list of PathProtos. Each PathProto
15
// belongs to a word_id and contains a list of PathNodeProtos. Each
16
// PathNodeProto contains information about the number of children the node has
17
// (length), the index of the child node that lies in the path from root to
18
// word_id (target) and a cumulative sum of children nodes (index; this acts as
19
// the weight parameter matrix offset).
21
// Each node in the hierarchy contains links to either leaf nodes or more
// non-terminal nodes.
24
// Links to non-terminal (internal) children nodes. Each child is itself a
// NodeProto, so the tree nests recursively.
// NOTE(review): very deep trees may run into the protobuf parser's recursion
// limit — confirm the maximum depth expected in practice.
25
repeated NodeProto children = 1;
26
// Links to terminal (leaf) nodes, given as the word_ids of the vocabulary
// words attached directly to this node.
27
repeated int32 word_ids = 2;
28
// NOTE(review): undocumented in this file; presumably this node's offset into
// the weight parameter matrix (compare PathNodeProto.index) — TODO confirm
// who sets it and when.
optional int32 offset = 3;
29
// NOTE(review): presumably a human-readable label for this node — confirm
// whether it has to be unique within the tree.
optional string name = 4;
30
// NOTE(review): undocumented in this file; the meaning and expected length of
// this list (e.g. one score per child?) must be confirmed against the code
// that builds the tree.
repeated float scores = 5;
33
// Protobuf format to accept hierarchy for hierarchical softmax operator.
34
// TreeProto points to the root node.
36
// Root NodeProto of the vocabulary tree; it recursively contains the internal
// nodes (children) and the leaf word_ids.
optional NodeProto root_node = 1;
39
// Internal Protobuf format which represents the path in the tree hierarchy for
40
// each word in the vocabulary.
41
message HierarchyProto {
42
// NOTE(review): undocumented in this file; presumably the total size of the
// hierarchy (e.g. number of weight-matrix rows or of nodes) — TODO confirm
// against the code that builds this message.
optional int32 size = 1;
43
// List of PathProtos; each PathProto belongs to one word_id and records the
// nodes on the way from the root to that word.
repeated PathProto paths = 2;
46
// Each PathProto belongs to a word and is an array of nodes in the
47
// path from the root to the leaf (which is the word itself) in the tree.
49
// Id of the vocabulary word (the leaf) that this path leads to.
optional int32 word_id = 1;
50
// Nodes that lie on the path from the root of the tree to word_id.
// NOTE(review): presumably ordered root-first — confirm.
repeated PathNodeProto path_nodes = 2;
53
// Represents a node in the path from the root node all the way down to the
// word (leaf).
55
message PathNodeProto {
56
// Parameter matrix offset for this node: a cumulative sum of children nodes,
// used as the offset into the weight parameter matrix.
57
optional int32 index = 1;
59
// Number of children this node has.
optional int32 length = 2;
60
// Index of the next node in the path, i.e. the index of the child node that
// lies on the path from the root to the word_id.
61
optional int32 target = 3;