# google-research · 556 lines · 17.8 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Encode tokens, entity references and predictions as numerical vectors."""

18import inspect19import json20import os21import sys22from typing import Any, List, Optional, Text, Tuple, Type, Union23
24from absl import logging25import numpy as np26import tensorflow as tf27
# Width of the sections that reserve one position per entity (the entity-id
# section and the group-membership section both use it as their SIZE).
MAX_NUM_ENTITIES = 20

# Type alias for an encoding array: a TensorFlow tensor or a NumPy array.
EnrefArray = Union[tf.Tensor, np.ndarray]


class Section(object):
  """Represents a section (i.e. a range on the last axis) of a data array.

  Attributes:
    array: The array (NumPy array or TensorFlow tensor) holding the data.
    start: Index on the last axis at which the section begins.
    size: Number of positions the section spans.
  """

  def __init__(self, array, start, size):
    self.array = array
    self.start = start
    self.size = size

  def slice(self):
    """Returns the part of the array covered by this section."""
    end = self.start + self.size
    return self.array[..., self.start:end]

  def replace(self, array):
    """Overwrites the section with `array` and returns the whole array.

    For a `tf.Tensor` a new tensor is assembled with `tf.concat` and stored in
    `self.array`; any other array is assigned to in place.

    Args:
      array: The new content of the section.

    Returns:
      The array stored in `self.array` after the replacement.
    """
    end = self.start + self.size
    if not isinstance(self.array, tf.Tensor):
      self.array[..., self.start:end] = array
      return self.array

    before = self.array[..., :self.start]
    after = self.array[..., end:]
    self.array = tf.concat([before, array, after], -1)
    return self.array


class TypeSection(Section):
  """A section which specifies the encoding type (enref, token, prediction)."""
  SIZE = 3

  # Positions of the flags relative to the start of the section.
  _ENREF_FLAG = 0
  _TOKEN_FLAG = 2

  def is_token(self):
    """Returns the value stored in the token flag."""
    return self.array[..., self.start + self._TOKEN_FLAG]

  def set_token(self):
    """Marks the encoding as a token: clears the enref flag, sets the token flag."""
    self.array[..., self.start + self._ENREF_FLAG] = 0.0
    self.array[..., self.start + self._TOKEN_FLAG] = 1.0

  def is_enref(self):
    """Returns the value stored in the enref flag."""
    return self.array[..., self.start + self._ENREF_FLAG]

  def set_enref(self):
    """Marks the encoding as an enref: sets the enref flag, clears the token flag."""
    self.array[..., self.start + self._ENREF_FLAG] = 1.0
    self.array[..., self.start + self._TOKEN_FLAG] = 0.0


class EnrefMetaSection(Section):
  """Encodes whether a token is an enref and if it's new or new continued."""
  SIZE = 3

  def _write_flag(self, offset, value):
    """Stores 1.0 for a truthy `value` and 0.0 otherwise at `start + offset`."""
    self.array[..., self.start + offset] = 1.0 if value else 0.0

  def is_enref(self):
    """Returns the value of the is-enref flag (first position)."""
    return self.array[..., self.start]

  def set_is_enref(self, value):
    self._write_flag(0, value)

  def is_new(self):
    """Returns the value of the is-new flag (second position)."""
    return self.array[..., self.start + 1]

  def set_is_new(self, value):
    self._write_flag(1, value)

  def is_new_continued(self):
    """Returns the value of the is-new-continued flag (third position)."""
    return self.array[..., self.start + 2]

  def set_is_new_continued(self, value):
    self._write_flag(2, value)

  def get_is_new_slice(self):
    """Returns the is-new and is-new-continued positions as one slice."""
    first = self.start + 1
    end = self.start + self.size
    return self.array[..., first:end]

  def replace_is_new_slice(self, array):
    """Replaces everything after the is-enref flag and returns the new array.

    The result is always assembled with `tf.concat` and stored in `self.array`.

    Args:
      array: New content for the is-new / is-new-continued positions.

    Returns:
      The concatenated array.
    """
    head = self.array[..., :self.start + 1]
    tail = self.array[..., (self.start + self.size):]
    self.array = tf.concat([head, array, tail], -1)
    return self.array


class EnrefIdSection(Section):
  """Encodes the id of an entity as a one-hot vector."""
  SIZE = MAX_NUM_ENTITIES

  def get(self):
    """Returns the position of the largest value within the section."""
    return np.argmax(self.slice())

  def set(self, enref_id):
    """Clears the section, then sets the position `enref_id` to 1.0."""
    end = self.start + self.size
    self.array[..., self.start:end] = 0.0
    self.array[..., self.start + enref_id] = 1.0


class EnrefPropertiesSection(Section):
  """Encodes the domain, the grammatical gender and the group flag of an enref.

  Layout relative to the start of the section: positions 0-1 hold the domain,
  positions 2-4 the gender and position 5 the group flag.
  """
  SIZE = 6
  DOMAINS = ['people', 'locations']
  PROPERTIES = ['female', 'male', 'neuter']

  def _read_label(self, begin, end, labels):
    """Returns the label with the largest value, or 'unknown' if none is > 0."""
    scores = self.array[..., self.start + begin:self.start + end]
    if np.max(scores) <= 0.0:
      return 'unknown'
    return labels[np.argmax(scores)]

  def _write_label(self, begin, end, labels, label):
    """Clears positions [begin, end) and marks the position of `label`."""
    self.array[..., self.start + begin:self.start + end] = 0.0
    if label == 'unknown':
      # 'unknown' is represented by leaving all positions at zero.
      return
    self.array[..., self.start + begin + labels.index(label)] = 1.0

  def get_domain(self):
    """Returns one of DOMAINS, or 'unknown'."""
    return self._read_label(0, 2, self.DOMAINS)

  def set_domain(self, domain):
    """Stores `domain`, which must be in DOMAINS or be 'unknown'."""
    self._write_label(0, 2, self.DOMAINS, domain)

  def get_gender(self):
    """Returns one of PROPERTIES, or 'unknown'."""
    return self._read_label(2, 5, self.PROPERTIES)

  def set_gender(self, gender):
    """Stores `gender`, which must be in PROPERTIES or be 'unknown'."""
    self._write_label(2, 5, self.PROPERTIES, gender)

  def is_group(self):
    """Returns the value of the group flag."""
    return self.array[..., self.start + 5]

  def set_is_group(self, value):
    flag = 1.0 if value else 0.0
    self.array[..., self.start + 5] = flag


class EnrefMembershipSection(Section):
  """Encodes the members of a group, if an enref refers to multiple entities."""
  SIZE = MAX_NUM_ENTITIES

  def __init__(self, array, start, size):
    super().__init__(array, start, size)
    # Optional member names; kept on the object only, not in the array.
    self.names = None

  def get_ids(self):
    """Returns the list of positions in the section whose value is > 0."""
    is_member = self.slice() > 0.0
    return np.where(is_member)[0].tolist()

  def get_names(self):
    """Returns the names passed to the last call of `set` (None initially)."""
    return self.names

  def set(self, ids, names=None):
    """Stores the member ids in the array and remembers their names.

    Args:
      ids: Iterable of entity ids to mark as members.
      names: Optional names of the members.
    """
    self.names = names
    end = self.start + self.size
    self.array[..., self.start:end] = 0.0
    for member_id in ids:
      self.array[..., self.start + member_id] = 1.0


class EnrefContextSection(Section):
  """Encodes if an enref is a sender or recipient and the message offset.

  Position 0 is the sender flag, position 1 the recipient flag; the remaining
  SIZE - 2 positions hold the message offset in binary, lowest bit first.
  """
  SIZE = 7

  def is_sender(self):
    """Returns the value of the sender flag."""
    return self.array[..., self.start]

  def set_is_sender(self, value):
    flag = 1.0 if value else 0.0
    self.array[..., self.start] = flag

  def is_recipient(self):
    """Returns the value of the recipient flag."""
    return self.array[..., self.start + 1]

  def set_is_recipient(self, value):
    flag = 1.0 if value else 0.0
    self.array[..., self.start + 1] = flag

  def get_message_offset(self):
    """Decodes the message offset from its binary representation."""
    offset = 0
    for bit in range(self.SIZE - 2):
      value = int(self.array[..., self.start + 2 + bit])
      offset += value * (1 << bit)
    return offset

  def set_message_offset(self, offset):
    """Stores the lowest SIZE - 2 bits of `offset`; higher bits are dropped."""
    for bit in range(self.SIZE - 2):
      bit_is_set = (offset >> bit) & 0x01
      self.array[..., self.start + 2 + bit] = 1.0 if bit_is_set else 0.0


class TokenPaddingSection(Section):
  """An empty section sized so that enref and token encodings align."""
  # Combined width of the entity-id, properties, membership and context
  # sections.
  SIZE = sum(
      section.SIZE
      for section in (EnrefIdSection, EnrefPropertiesSection,
                      EnrefMembershipSection, EnrefContextSection))


class SignalSection(Section):
  """Encodes optional token signals collected during preprocessing."""
  SIZE = 10
  # Maps each signal name to its position within the section.
  SIGNALS = {
      'first_name': 0,
      'sports_team': 1,
      'athlete': 2,
  }

  def set(self, signals):
    """Clears the section and switches on the given signals.

    Args:
      signals: Iterable of signal names; each must be a key of SIGNALS.

    Raises:
      KeyError: If a signal name is not in SIGNALS.
    """
    self.array[..., self.start:(self.start + self.size)] = 0.0
    for signal in signals:
      self.array[..., self.start + self.SIGNALS[signal]] = 1.0

  def get(self):
    """Returns the names of all signals whose position holds a value > 0."""
    # Use the index stored for each signal rather than its rank in the dict, so
    # that reading stays consistent with `set` even if SIGNALS is reordered or
    # its indices are not contiguous.
    return [
        signal for signal, index in self.SIGNALS.items()
        if self.array[..., self.start + index] > 0.0
    ]


class WordvecSection(Section):
  """Contains the word2vec embedding of a token."""
  SIZE = 300

  def get(self):
    """Returns the embedding stored in this section."""
    embedding = self.slice()
    return embedding

  def set(self, wordvec):
    """Writes `wordvec` into this section."""
    end = self.start + self.size
    self.array[..., self.start:end] = wordvec


class BertSection(Section):
  """Contains the BERT embedding of a token."""
  SIZE = 768

  def get(self):
    """Returns the embedding stored in this section."""
    embedding = self.slice()
    return embedding

  def set(self, bertvec):
    """Writes `bertvec` into this section."""
    end = self.start + self.size
    self.array[..., self.start:end] = bertvec


class Encoding(object):
  """Provides an API to access data within an array.

  Every entry of the layout becomes an attribute of this object: a section
  object that covers the next `SIZE` positions of the array.

  Attributes:
    array: The wrapped array.
    sections_size: Total number of positions covered by all sections.
  """

  def __init__(self, array, layout):
    """Creates the section objects described by `layout`.

    Args:
      array: A NumPy array or TensorFlow tensor holding the encoded data.
      layout: Sequence of (attribute name, section class) pairs, in the order
        in which the sections are stored in the array.
    """
    assert isinstance(array, (np.ndarray, tf.Tensor))

    self.array = array
    offset = 0
    for name, section_class in layout:
      width = section_class.SIZE
      setattr(self, name, section_class(array, offset, width))
      offset += width

    self.sections_size = offset


class EnrefEncoding(Encoding):
  """An API to access and modify contrack entity references within an array.

  Attributes:
    entity_name: Name of the referenced entity, or None if not populated.
    word_span: (start, end) pair of word positions, or None if not populated.
    span_text: Text of the span, or None if not populated.
  """

  def __init__(self, array, layout):
    Encoding.__init__(self, array, layout)

    self.entity_name = None
    self.word_span = None
    self.span_text = None

  def populate(self, entity_name, word_span, span_text):
    """Attaches the descriptive fields that are not stored in the array."""
    self.entity_name = entity_name
    self.word_span = word_span
    self.span_text = span_text

  def __repr__(self):
    parts = []
    if self.entity_name is not None:
      parts.append('%s ' % self.entity_name)

    new_mark = 'n' if self.enref_meta.is_new() > 0.0 else ''
    continued_mark = 'c' if self.enref_meta.is_new_continued() > 0.0 else ''
    parts.append('(%d%s%s) ' % (self.enref_id.get(), new_mark, continued_mark))

    if self.word_span is not None:
      parts.append('%d-%d ' % self.word_span)
    if self.span_text is not None:
      parts.append('(%s) ' % self.span_text)

    if self.enref_properties is not None:
      is_group = self.enref_properties.is_group() > 0.0
      domain = self.enref_properties.get_domain()
      # Only the first letter of the domain is shown.
      parts.append(domain[0])
      if domain == 'people' and not is_group:
        parts.append(':' + self.enref_properties.get_gender())
      if is_group:
        parts.append(':g %s' % self.enref_membership.get_ids())

    if self.signals is not None:
      active_signals = self.signals.get()
      if active_signals:
        parts.append(str(active_signals))

    return ''.join(parts)


class TokenEncoding(Encoding):
  """An API to access and modify contrack tokens within an array.

  Attributes:
    token: The token text, or None until `populate` has been called.
  """

  def __init__(self, array, layout):
    Encoding.__init__(self, array, layout)

    # Defined up front so that __repr__ also works for encodings which wrap an
    # existing array and were never populated.
    self.token = None

  def populate(self, token, signals, wordvec, bertvec):
    """Stores the token text and writes its features into the array.

    Args:
      token: The token text; kept on the object only, not in the array.
      signals: Iterable of signal names to switch on.
      wordvec: Content for the `wordvec` section.
      bertvec: Content for the `bert` section.
    """
    self.token = token
    self.signals.set(signals)
    self.wordvec.set(wordvec)
    self.bert.set(bertvec)

  def __repr__(self):
    signals = self.signals.get()
    signals_str = str(signals) if signals else ''
    return '%s%s' % (self.token, signals_str)


class PredictionEncoding(Encoding):
  """An API to access and modify prediction values stored in an array."""

  def __init__(self, array, layout):
    Encoding.__init__(self, array, layout)

  def __repr__(self):
    new_mark = 'n' if self.enref_meta.is_new() > 0.0 else ''
    continued_mark = 'c' if self.enref_meta.is_new_continued() > 0.0 else ''
    parts = ['(%d%s%s) ' % (self.enref_id.get(), new_mark, continued_mark)]

    if self.enref_properties is not None:
      is_group = self.enref_properties.is_group() > 0.0
      domain = self.enref_properties.get_domain()
      # Only the first letter of the domain is shown.
      parts.append(domain[0])
      if domain == 'people' and not is_group:
        parts.append(':' + self.enref_properties.get_gender())
      if is_group:
        parts.append(': %s' % self.enref_membership.get_ids())

    return ''.join(parts)


class Encodings(object):
  """Organize access to data encoded in numerical vectors.

  Holds three layouts, each a list of (attribute name, section class) pairs,
  together with the total length of each layout:

    * enref_encoding_layout / enref_encoding_length
    * token_encoding_layout / token_encoding_length
    * prediction_encoding_layout / prediction_encoding_length

  The enref and token layouts must have the same total length.
  """

  def __init__(self):
    self.enref_encoding_layout = [
        ('type', TypeSection),
        ('enref_meta', EnrefMetaSection),
        ('enref_id', EnrefIdSection),
        ('enref_properties', EnrefPropertiesSection),
        ('enref_membership', EnrefMembershipSection),
        ('enref_context', EnrefContextSection),
        ('signals', SignalSection),
        ('wordvec', WordvecSection),
        ('bert', BertSection),
    ]
    self.enref_encoding_length = self._layout_length(
        self.enref_encoding_layout)
    logging.info('EnrefEncoding (length: %d): %s', self.enref_encoding_length,
                 self._describe_layout(self.enref_encoding_layout))

    self.token_encoding_layout = [
        ('type', TypeSection),
        ('enref_meta', EnrefMetaSection),
        ('padding', TokenPaddingSection),
        ('signals', SignalSection),
        ('wordvec', WordvecSection),
        ('bert', BertSection),
    ]
    self.token_encoding_length = self._layout_length(
        self.token_encoding_layout)
    assert self.enref_encoding_length == self.token_encoding_length
    logging.info('TokenEncoding (length: %d): %s', self.token_encoding_length,
                 self._describe_layout(self.token_encoding_layout))

    self.prediction_encoding_layout = [
        ('enref_meta', EnrefMetaSection),
        ('enref_id', EnrefIdSection),
        ('enref_properties', EnrefPropertiesSection),
        ('enref_membership', EnrefMembershipSection),
    ]
    self.prediction_encoding_length = self._layout_length(
        self.prediction_encoding_layout)
    logging.info('PredictionEncoding (length: %d): %s',
                 self.prediction_encoding_length,
                 self._describe_layout(self.prediction_encoding_layout))

  @staticmethod
  def _layout_length(layout):
    """Returns the summed SIZE of all section classes in `layout`."""
    return sum([section_class.SIZE for (_, section_class) in layout])

  @staticmethod
  def _describe_layout(layout):
    """Returns a '<name>: <size>' string for every section, for logging."""
    return [f'{s}: {c.SIZE}' for s, c in layout]

  @staticmethod
  def _resolve_section_class(members, cls_name):
    """Finds the module member that a class name from a json file refers to.

    An exact name match wins. Otherwise the first member whose name ends with
    `cls_name` is used, which keeps files working that relied on the suffix
    lookup.

    Args:
      members: (name, object) pairs as returned by `inspect.getmembers`.
      cls_name: Class name read from the json file.

    Returns:
      The matching module member.

    Raises:
      ValueError: If no member matches `cls_name`.
    """
    suffix_matches = []
    for name, member in members:
      if name == cls_name:
        return member
      if name.endswith(cls_name):
        suffix_matches.append(member)
    if not suffix_matches:
      raise ValueError(
          'Unknown section class in encoding layout: %r' % (cls_name,))
    return suffix_matches[0]

  @classmethod
  def _layout_from_json(cls, entries, members):
    """Converts (name, class name) pairs into (name, section class) pairs."""
    return [(name, cls._resolve_section_class(members, cls_name))
            for name, cls_name in entries]

  @classmethod
  def load_from_json(cls, path):
    """Loads the encoding layout from a json file.

    Args:
      path: Path of a json file in the format written by `save`.

    Returns:
      A new instance of `cls` whose layouts and lengths are taken from the
      file.

    Raises:
      ValueError: If the file names a section class that does not exist.
    """
    members = inspect.getmembers(sys.modules[__name__])
    with tf.io.gfile.GFile(path, 'r') as file:
      encodings_dict = json.loads(file.read())

    enc = cls()

    enc.enref_encoding_layout = cls._layout_from_json(
        encodings_dict['enref_encoding_layout'], members)
    enc.enref_encoding_length = cls._layout_length(enc.enref_encoding_layout)

    enc.token_encoding_layout = cls._layout_from_json(
        encodings_dict['token_encoding_layout'], members)
    enc.token_encoding_length = cls._layout_length(enc.token_encoding_layout)
    assert enc.enref_encoding_length == enc.token_encoding_length

    enc.prediction_encoding_layout = cls._layout_from_json(
        encodings_dict['prediction_encoding_layout'], members)
    enc.prediction_encoding_length = cls._layout_length(
        enc.prediction_encoding_layout)

    return enc

  def as_enref_encoding(self, array):
    """Wraps an existing array in an EnrefEncoding."""
    return EnrefEncoding(array, self.enref_encoding_layout)

  def new_enref_array(self):
    """Returns a zero-filled 1-D array of the enref encoding length."""
    return np.zeros(self.enref_encoding_length)

  def new_enref_encoding(self):
    """Returns an EnrefEncoding over a fresh array with the enref type set."""
    enc = EnrefEncoding(self.new_enref_array(), self.enref_encoding_layout)
    enc.type.set_enref()
    return enc

  def as_token_encoding(self, array):
    """Wraps an existing array in a TokenEncoding."""
    return TokenEncoding(array, self.token_encoding_layout)

  def new_token_array(self):
    """Returns a zero-filled 1-D array of the token encoding length."""
    return np.zeros(self.token_encoding_length)

  def new_token_encoding(self, token, signals, wordvec, bertvec):
    """Returns a populated TokenEncoding over a fresh array."""
    enc = TokenEncoding(self.new_token_array(), self.token_encoding_layout)
    enc.type.set_token()
    enc.populate(token, signals, wordvec, bertvec)
    return enc

  def as_prediction_encoding(self, array):
    """Wraps an existing array in a PredictionEncoding."""
    return PredictionEncoding(array, self.prediction_encoding_layout)

  def new_prediction_array(self):
    """Returns a zero-filled 1-D array of the prediction encoding length."""
    return np.zeros(self.prediction_encoding_length)

  def new_prediction_encoding(self):
    """Returns a PredictionEncoding over a fresh zero-filled array."""
    return PredictionEncoding(self.new_prediction_array(),
                              self.prediction_encoding_layout)

  @staticmethod
  def _binarize(section):
    """Maps the values of a section to 1.0 where they are > 0, else 0.0."""
    return np.where(section.slice() > 0.0, 1.0, 0.0)

  def build_enref_from_prediction(self, token, prediction):
    """Build new enref from prediction logits.

    Args:
      token: Encoding of the token; its array is copied as the starting point.
      prediction: PredictionEncoding holding the logits for that token.

    Returns:
      A new EnrefEncoding, or None if the prediction's is-enref value is not
      positive.
    """
    if prediction.enref_meta.is_enref() <= 0.0:
      return None

    # Work on a copy so that the token's own array stays untouched.
    enref = self.as_enref_encoding(np.array(token.array))
    enref.type.set_enref()

    enref.enref_meta.replace(self._binarize(prediction.enref_meta))
    enref.enref_id.set(prediction.enref_id.get())
    enref.enref_properties.replace(
        self._binarize(prediction.enref_properties))
    if prediction.enref_properties.is_group() > 0.0:
      enref.enref_membership.replace(
          self._binarize(prediction.enref_membership))
    else:
      enref.enref_membership.set([])
    enref.signals.set([])

    return enref

  @staticmethod
  def _find_enref_spans(predictions):
    """Returns (start, end) index pairs of runs that share one enref id.

    A run consists of consecutive predictions whose is-enref value is positive
    and whose enref id is the same; `end` is exclusive.
    """
    spans = []
    open_span = None  # (start index, enref id) of the run being extended.
    for i, pred_enc in enumerate(predictions):
      if open_span is not None and (
          pred_enc.enref_meta.is_enref() <= 0.0 or
          open_span[1] != pred_enc.enref_id.get()):
        spans.append((open_span[0], i))
        open_span = None
      if open_span is None and pred_enc.enref_meta.is_enref() > 0.0:
        open_span = (i, pred_enc.enref_id.get())
    if open_span is not None:
      spans.append((open_span[0], len(predictions)))
    return spans

  def build_enrefs_from_predictions(self, tokens, predictions, words,
                                    prev_enrefs):
    """Build new enrefs from prediction logits.

    Args:
      tokens: Token encodings, one per position.
      predictions: Prediction encodings, aligned with `tokens`.
      words: The word at each position, aligned with `tokens`.
      prev_enrefs: Previously known enrefs, used to look up the entity name of
        an enref that is not new.

    Returns:
      A list with one populated EnrefEncoding per detected span.
    """
    enrefs = []
    for (start, end) in self._find_enref_spans(predictions):
      positions = range(start, end)
      enref = self.build_enref_from_prediction(tokens[start],
                                               predictions[start])
      # The embeddings of a span are the mean over its tokens.
      enref.wordvec.set(np.mean([tokens[i].wordvec.get() for i in positions],
                                0))
      enref.bert.set(np.mean([tokens[i].bert.get() for i in positions], 0))
      span_text = ' '.join([words[i] for i in positions])

      # A new entity is named after its first word; a known entity keeps the
      # name of the first previous enref with the same id, if there is one.
      name = words[start]
      if enref.enref_meta.is_new() <= 0.0:
        for known in prev_enrefs:
          if known.enref_id.get() == enref.enref_id.get():
            name = known.entity_name
            break
      enref.populate(name, (start, end), span_text)
      enrefs.append(enref)

    return enrefs

  def save(self, path):
    """Saves encoding to json file.

    Args:
      path: Directory in which the file 'encodings.json' is written.
    """

    def to_json(layout):
      return [(name, section_class.__name__)
              for (name, section_class) in layout]

    encodings_dict = {
        'enref_encoding_layout': to_json(self.enref_encoding_layout),
        'token_encoding_layout': to_json(self.token_encoding_layout),
        'prediction_encoding_layout': to_json(self.prediction_encoding_layout),
    }

    filepath = os.path.join(path, 'encodings.json')
    with tf.io.gfile.GFile(filepath, 'w') as file:
      json.dump(encodings_dict, file, indent=2)