google-research

Форк
0
556 строк · 17.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Encode tokens, entity references and predictions as numerical vectors."""
17

18
import inspect
19
import json
20
import os
21
import sys
22
from typing import Any, List, Optional, Text, Tuple, Type, Union
23

24
from absl import logging
25
import numpy as np
26
import tensorflow as tf
27

28
# Number of distinct entity ids an encoding can represent: EnrefIdSection and
# EnrefMembershipSection each reserve one column per id.
MAX_NUM_ENTITIES = 20
29

30
# Either array type accepted by Encoding: a TensorFlow tensor or a NumPy array.
EnrefArray = Union[tf.Tensor, np.ndarray]
31

32

33
class Section(object):
  """A contiguous run of columns along the last axis of a data array.

  Attributes:
    array: the backing np.ndarray or tf.Tensor.
    start: index of the section's first column in the last axis.
    size: number of columns covered by the section.
  """

  def __init__(self, array, start, size):
    self.array, self.start, self.size = array, start, size

  def slice(self):
    """Returns the section's columns (a view when `array` is a NumPy array)."""
    end = self.start + self.size
    return self.array[Ellipsis, self.start:end]

  def replace(self, array):
    """Overwrites the section's columns with `array` and returns the result.

    NumPy arrays are updated in place. TensorFlow tensors are immutable, so a
    new tensor is assembled with tf.concat and stored in `self.array`; other
    objects still referencing the old tensor do not see the change.
    """
    end = self.start + self.size
    if not isinstance(self.array, tf.Tensor):
      self.array[Ellipsis, self.start:end] = array
      return self.array
    head = self.array[Ellipsis, :self.start]
    tail = self.array[Ellipsis, end:]
    self.array = tf.concat([head, array, tail], -1)
    return self.array
53

54

55
class TypeSection(Section):
  """A section which specifies the encoding type (enref, token, prediction).

  Column 0 (relative to `start`) flags an enref and column 2 flags a token.
  Column 1 is neither read nor written here (presumably the prediction flag;
  not verified from this class).
  """
  SIZE = 3

  def is_token(self):
    """Returns the raw value of the token flag."""
    token_col = self.start + 2
    return self.array[Ellipsis, token_col]

  def set_token(self):
    """Clears the enref flag and sets the token flag."""
    enref_col, token_col = self.start, self.start + 2
    self.array[Ellipsis, enref_col] = 0.0
    self.array[Ellipsis, token_col] = 1.0

  def is_enref(self):
    """Returns the raw value of the enref flag."""
    enref_col = self.start
    return self.array[Ellipsis, enref_col]

  def set_enref(self):
    """Sets the enref flag and clears the token flag."""
    enref_col, token_col = self.start, self.start + 2
    self.array[Ellipsis, enref_col] = 1.0
    self.array[Ellipsis, token_col] = 0.0
72

73

74
class EnrefMetaSection(Section):
  """Encodes whether a token is an enref and if it is new or new continued.

  Columns relative to `start`: 0 = is enref, 1 = is new, 2 = is new
  continued. Getters return the stored value unchanged; setters store 1.0
  for a truthy `value` and 0.0 otherwise.
  """
  SIZE = 3

  def is_enref(self):
    return self.array[Ellipsis, self.start]

  def set_is_enref(self, value):
    self.array[Ellipsis, self.start] = float(bool(value))

  def is_new(self):
    col = self.start + 1
    return self.array[Ellipsis, col]

  def set_is_new(self, value):
    col = self.start + 1
    self.array[Ellipsis, col] = float(bool(value))

  def is_new_continued(self):
    col = self.start + 2
    return self.array[Ellipsis, col]

  def set_is_new_continued(self, value):
    col = self.start + 2
    self.array[Ellipsis, col] = float(bool(value))

  def get_is_new_slice(self):
    """Returns the 'is new' and 'is new continued' columns together."""
    first, end = self.start + 1, self.start + self.size
    return self.array[Ellipsis, first:end]

  def replace_is_new_slice(self, array):
    """Replaces the 'is new'/'is new continued' columns with `array`.

    Always rebuilds the data with tf.concat (there is no in-place NumPy
    path), so the stored and returned `self.array` is a tensor.
    """
    keep_until = self.start + 1
    resume_at = self.start + self.size
    pieces = [
        self.array[Ellipsis, :keep_until],
        array,
        self.array[Ellipsis, resume_at:],
    ]
    self.array = tf.concat(pieces, -1)
    return self.array
105

106

107
class EnrefIdSection(Section):
  """One-hot encoding of an entity id; one column per possible id."""
  SIZE = MAX_NUM_ENTITIES

  def get(self):
    """Returns the argmax over the (flattened) section, i.e. the entity id."""
    return np.argmax(self.slice())

  def set(self, enref_id):
    """Clears the section and sets column `enref_id` to 1.0.

    `enref_id` is not range-checked; an id >= SIZE writes past the section.
    """
    first = self.start
    self.array[Ellipsis, first:first + self.size] = 0.0
    self.array[Ellipsis, first + enref_id] = 1.0
117

118

119
class EnrefPropertiesSection(Section):
  """Encodes the domain, grammatical gender and whether an enref is a group.

  Columns relative to `start`:
    0-1: domain one-hot, in DOMAINS order.
    2-4: gender one-hot, in PROPERTIES order.
    5:   group flag.
  A one-hot group whose values are all <= 0.0 decodes to 'unknown', and
  setting 'unknown' leaves the group all zeros.
  """
  SIZE = 6
  DOMAINS = ['people', 'locations']
  PROPERTIES = ['female', 'male', 'neuter']

  def _read_one_hot(self, offset, width, labels):
    """Decodes the `width` columns at `offset` into one of `labels`."""
    lo = self.start + offset
    group = self.array[Ellipsis, lo:lo + width]
    if np.max(group) <= 0.0:
      return 'unknown'
    return labels[np.argmax(group)]

  def _write_one_hot(self, offset, width, labels, label):
    """Zeroes the group at `offset`, then sets the column for `label`.

    Raises ValueError (from list.index) for a label not in `labels`.
    """
    lo = self.start + offset
    self.array[Ellipsis, lo:lo + width] = 0.0
    if label == 'unknown':
      return
    self.array[Ellipsis, lo + labels.index(label)] = 1.0

  def get_domain(self):
    return self._read_one_hot(0, 2, self.DOMAINS)

  def set_domain(self, domain):
    self._write_one_hot(0, 2, self.DOMAINS, domain)

  def get_gender(self):
    return self._read_one_hot(2, 3, self.PROPERTIES)

  def set_gender(self, gender):
    self._write_one_hot(2, 3, self.PROPERTIES, gender)

  def is_group(self):
    """Returns the raw value of the group flag."""
    return self.array[Ellipsis, self.start + 5]

  def set_is_group(self, value):
    """Stores 1.0 for a truthy `value`, else 0.0."""
    self.array[Ellipsis, self.start + 5] = float(bool(value))
158

159

160
class EnrefMembershipSection(Section):
  """Encodes the members of a group, if an enref refers to multiple entities.

  Column `i` (relative to `start`) is positive when entity id `i` is a
  member. Optional member names are kept on the Python object only; they are
  not written into the array.
  """
  SIZE = MAX_NUM_ENTITIES

  def __init__(self, array, start, size):
    super().__init__(array, start, size)
    self.names = None

  def get_ids(self):
    """Returns the list of column indices whose value is > 0.0."""
    return np.nonzero(self.slice() > 0.0)[0].tolist()

  def get_names(self):
    """Returns the names passed to the last `set` call (None by default)."""
    return self.names

  def set(self, ids, names=None):
    """Stores `names`, clears the section and sets a 1.0 per id in `ids`."""
    self.names = names
    lo = self.start
    self.array[Ellipsis, lo:lo + self.size] = 0.0
    for member_id in ids:
      self.array[Ellipsis, lo + member_id] = 1.0
180

181

182
class EnrefContextSection(Section):
  """Encodes if an enref is a sender or recipient and the message offset.

  Columns relative to `start`: 0 = is sender, 1 = is recipient, and
  2..SIZE-1 = message offset in binary, least-significant bit first
  (5 bits with SIZE = 7).
  """
  SIZE = 7

  def is_sender(self):
    return self.array[Ellipsis, self.start]

  def set_is_sender(self, value):
    self.array[Ellipsis, self.start] = float(bool(value))

  def is_recipient(self):
    col = self.start + 1
    return self.array[Ellipsis, col]

  def set_is_recipient(self, value):
    col = self.start + 1
    self.array[Ellipsis, col] = float(bool(value))

  def get_message_offset(self):
    """Decodes the offset bits (each converted with int()) to an integer."""
    return sum(
        int(self.array[Ellipsis, self.start + col]) << (col - 2)
        for col in range(2, self.SIZE))

  def set_message_offset(self, offset):
    """Writes the low SIZE - 2 bits of `offset`; higher bits are dropped."""
    for col in range(2, self.SIZE):
      bit = (offset >> (col - 2)) & 0x01
      self.array[Ellipsis, self.start + col] = 1.0 if bit else 0.0
213

214

215
class TokenPaddingSection(Section):
  """An empty section sized so that enref and token encodings align.

  Its width equals the combined width of the enref-only sections (id,
  properties, membership, context), so the sections after it sit at the
  same columns in token and enref vectors.
  """
  SIZE = sum(section.SIZE for section in (
      EnrefIdSection, EnrefPropertiesSection, EnrefMembershipSection,
      EnrefContextSection))
220

221

222
class SignalSection(Section):
  """Encodes optional token signals collected during preprocessing.

  SIGNALS maps each signal name to its column (relative to `start`). The
  section reserves SIZE columns; only the columns listed in SIGNALS are used.
  """
  SIZE = 10
  SIGNALS = {
      'first_name': 0,
      'sports_team': 1,
      'athlete': 2,
  }

  def set(self, signals):
    """Clears all columns, then sets the column of each named signal to 1.0.

    Raises:
      KeyError: if a name in `signals` is not a key of SIGNALS.
    """
    lo = self.start
    self.array[Ellipsis, lo:lo + self.size] = 0.0
    for signal in signals:
      self.array[Ellipsis, lo + self.SIGNALS[signal]] = 1.0

  def get(self):
    """Returns the names of signals whose column is > 0.0, in column order.

    Columns are looked up through the index declared in SIGNALS (the same
    index `set` writes to) instead of the dict's iteration position, so
    reading and writing cannot disagree if SIGNALS is reordered.
    """
    by_column = sorted(self.SIGNALS.items(), key=lambda item: item[1])
    return [
        name for name, index in by_column
        if self.array[Ellipsis, self.start + index] > 0.0
    ]
243

244

245
class WordvecSection(Section):
  """Holds the 300-dimensional word2vec embedding of a token."""
  SIZE = 300

  def get(self):
    """Returns the embedding columns (identical to `slice()`)."""
    return self.slice()

  def set(self, wordvec):
    """Assigns `wordvec` to the embedding columns in place."""
    end = self.start + self.size
    self.array[Ellipsis, self.start:end] = wordvec
254

255

256
class BertSection(Section):
  """Holds the 768-dimensional BERT embedding of a token."""
  SIZE = 768

  def get(self):
    """Returns the embedding columns (identical to `slice()`)."""
    return self.slice()

  def set(self, bertvec):
    """Assigns `bertvec` to the embedding columns in place."""
    end = self.start + self.size
    self.array[Ellipsis, self.start:end] = bertvec
265

266

267
class Encoding(object):
  """Provides an API to access data within an array.

  For every `(name, section_class)` pair in `layout`, an instance of
  `section_class` covering the next `section_class.SIZE` columns is attached
  as attribute `name`. All sections are created over the same `array`.
  """

  def __init__(self, array, layout):
    assert isinstance(array, (np.ndarray, tf.Tensor))

    self.array = array
    offset = 0
    for attr_name, section_class in layout:
      width = section_class.SIZE
      setattr(self, attr_name, section_class(array, offset, width))
      offset += width
    # Total number of columns claimed by the layout.
    self.sections_size = offset
281

282

283
class EnrefEncoding(Encoding):
  """An API to access and modify contrack entity references within an array.

  Besides the array-backed sections, carries Python-side metadata that is
  not stored in the array: entity name, word span and span text.
  """

  def __init__(self, array, layout):
    Encoding.__init__(self, array, layout)

    self.entity_name = None
    self.word_span = None
    self.span_text = None

  def populate(self, entity_name, word_span,
               span_text):
    """Sets the Python-side metadata.

    Args:
      entity_name: name of the referenced entity.
      word_span: (start, end) pair of word indices; formatted with '%d-%d'.
      span_text: text of the mention.
    """
    self.entity_name = entity_name
    self.word_span = word_span
    self.span_text = span_text

  def __repr__(self):
    """Returns a compact summary such as 'name (3n) 0-2 (text) p:female'.

    Optional sections (properties, membership, signals) that are missing
    from a custom layout are skipped instead of raising AttributeError.
    """
    meta = self.enref_meta
    descr = ''
    if self.entity_name is not None:
      descr += '%s ' % self.entity_name

    descr += '(%d%s%s) ' % (self.enref_id.get(),
                            'n' if meta.is_new() > 0.0 else '',
                            'c' if meta.is_new_continued() > 0.0 else '')
    if self.word_span is not None:
      # tuple() so a list span (e.g. deserialized) also formats correctly.
      descr += '%d-%d ' % tuple(self.word_span)
    if self.span_text is not None:
      descr += '(%s) ' % self.span_text

    properties = getattr(self, 'enref_properties', None)
    if properties is not None:
      is_group = properties.is_group() > 0.0
      domain = properties.get_domain()
      descr += domain[0]
      if domain == 'people' and not is_group:
        descr += ':' + properties.get_gender()
      membership = getattr(self, 'enref_membership', None)
      if is_group and membership is not None:
        descr += ':g %s' % membership.get_ids()

    signals = getattr(self, 'signals', None)
    active_signals = signals.get() if signals is not None else None
    if active_signals:
      descr += str(active_signals)
    return descr
322

323

324
class TokenEncoding(Encoding):
  """An API to access and modify contrack tokens within an array.

  `token` is kept on the Python object only (it is not written into the
  array) and is None until `populate` is called.
  """

  def __init__(self, array, layout):
    Encoding.__init__(self, array, layout)
    # Initialized here so that __repr__ also works for encodings that wrap an
    # existing array without calling populate().
    self.token = None

  def populate(self, token, signals, wordvec,
               bertvec):
    """Stores `token` and writes signals and embeddings into the array.

    Args:
      token: the token as given by the caller (kept on the object only).
      signals: iterable of signal names understood by the signals section.
      wordvec: values for the word2vec section.
      bertvec: values for the BERT section.
    """
    self.token = token
    self.signals.set(signals)
    self.wordvec.set(wordvec)
    self.bert.set(bertvec)

  def __repr__(self):
    signals = self.signals.get()
    signals_str = str(signals) if signals else ''
    return '%s%s' % (self.token, signals_str)
341

342

343
class PredictionEncoding(Encoding):
  """An API to access and modify prediction values stored in an array."""

  def __init__(self, array, layout):
    Encoding.__init__(self, array, layout)

  def __repr__(self):
    """Returns a compact summary such as '(3n) p:female'.

    Optional sections (properties, membership) that are missing from a
    custom layout are skipped instead of raising AttributeError.
    """
    meta = self.enref_meta
    descr = '(%d%s%s) ' % (self.enref_id.get(),
                           'n' if meta.is_new() > 0.0 else '',
                           'c' if meta.is_new_continued() > 0.0 else '')
    properties = getattr(self, 'enref_properties', None)
    if properties is not None:
      is_group = properties.is_group() > 0.0
      domain = properties.get_domain()
      descr += domain[0]
      if domain == 'people' and not is_group:
        descr += ':' + properties.get_gender()
      membership = getattr(self, 'enref_membership', None)
      if is_group and membership is not None:
        descr += ': %s' % membership.get_ids()
    return descr
362

363

364
class Encodings(object):
  """Organize access to data encoded in numerical vectors.

  Holds three layouts (lists of `(attribute_name, Section subclass)` pairs)
  for enrefs, tokens and predictions, and provides factories that create
  zero-filled arrays and wrap arrays in the matching Encoding view. The enref
  and token layouts must have the same total length.
  """

  def __init__(self):
    self.enref_encoding_layout = [('type', TypeSection),
                                  ('enref_meta', EnrefMetaSection),
                                  ('enref_id', EnrefIdSection),
                                  ('enref_properties', EnrefPropertiesSection),
                                  ('enref_membership', EnrefMembershipSection),
                                  ('enref_context', EnrefContextSection),
                                  ('signals', SignalSection),
                                  ('wordvec', WordvecSection),
                                  ('bert', BertSection)]
    self.enref_encoding_length = self._layout_length(
        self.enref_encoding_layout)
    logging.info('EnrefEncoding (length: %d): %s', self.enref_encoding_length,
                 self._describe_layout(self.enref_encoding_layout))

    self.token_encoding_layout = [('type', TypeSection),
                                  ('enref_meta', EnrefMetaSection),
                                  ('padding', TokenPaddingSection),
                                  ('signals', SignalSection),
                                  ('wordvec', WordvecSection),
                                  ('bert', BertSection)]
    self.token_encoding_length = self._layout_length(
        self.token_encoding_layout)
    # Internal invariant of the built-in layouts (padding keeps them aligned).
    assert self.enref_encoding_length == self.token_encoding_length
    logging.info('TokenEncoding (length: %d): %s', self.token_encoding_length,
                 self._describe_layout(self.token_encoding_layout))

    self.prediction_encoding_layout = [
        ('enref_meta', EnrefMetaSection),
        ('enref_id', EnrefIdSection),
        ('enref_properties', EnrefPropertiesSection),
        ('enref_membership', EnrefMembershipSection),
    ]
    self.prediction_encoding_length = self._layout_length(
        self.prediction_encoding_layout)
    logging.info('PredictionEncoding (length: %d): %s',
                 self.prediction_encoding_length,
                 self._describe_layout(self.prediction_encoding_layout))

  @staticmethod
  def _layout_length(layout):
    """Returns the total number of columns spanned by `layout`."""
    return sum(section_cls.SIZE for _, section_cls in layout)

  @staticmethod
  def _describe_layout(layout):
    """Returns 'name: SIZE' strings describing `layout`, for logging."""
    return [f'{s}: {c.SIZE}' for s, c in layout]

  @staticmethod
  def _resolve_layout(entries):
    """Maps serialized `[name, class_name]` pairs to `(name, class)` pairs.

    Only Section subclasses found in this module are considered. An exact
    class-name match wins; otherwise the first class (in the alphabetical
    order of inspect.getmembers) whose name ends with `class_name` is used,
    matching the previous suffix-based lookup.

    Raises:
      ValueError: if no Section subclass matches a class name.
    """
    section_classes = [
        (member_name, member)
        for member_name, member in inspect.getmembers(
            sys.modules[__name__], inspect.isclass)
        if issubclass(member, Section)
    ]
    exact = dict(section_classes)

    def resolve(class_name):
      if class_name in exact:
        return exact[class_name]
      for member_name, member in section_classes:
        if member_name.endswith(class_name):
          return member
      raise ValueError('Unknown section class in encoding layout: %r' %
                       (class_name,))

    return [(name, resolve(class_name)) for name, class_name in entries]

  @classmethod
  def load_from_json(cls, path):
    """Loads the encoding layout from a json file.

    Args:
      path: path of the JSON file itself. Note that `save` takes a directory
        and writes 'encodings.json' into it.

    Returns:
      A new instance whose three layouts and lengths come from the file.

    Raises:
      ValueError: if a layout names an unknown Section class, or if the
        enref and token layouts differ in total length.
    """
    with tf.io.gfile.GFile(path, 'r') as file:
      encodings_dict = json.loads(file.read())

    enc = cls()
    enc.enref_encoding_layout = cls._resolve_layout(
        encodings_dict['enref_encoding_layout'])
    enc.enref_encoding_length = cls._layout_length(enc.enref_encoding_layout)

    enc.token_encoding_layout = cls._resolve_layout(
        encodings_dict['token_encoding_layout'])
    enc.token_encoding_length = cls._layout_length(enc.token_encoding_layout)
    # Validate external data with an exception: asserts vanish under -O.
    if enc.enref_encoding_length != enc.token_encoding_length:
      raise ValueError(
          'Enref and token layouts must have equal length, got %d and %d.' %
          (enc.enref_encoding_length, enc.token_encoding_length))

    enc.prediction_encoding_layout = cls._resolve_layout(
        encodings_dict['prediction_encoding_layout'])
    enc.prediction_encoding_length = cls._layout_length(
        enc.prediction_encoding_layout)

    return enc

  def as_enref_encoding(self, array):
    """Wraps `array` (not copied) in an EnrefEncoding view."""
    return EnrefEncoding(array, self.enref_encoding_layout)

  def new_enref_array(self):
    """Returns a zero-filled float64 vector of the enref encoding length."""
    return np.zeros(self.enref_encoding_length)

  def new_enref_encoding(self):
    """Returns a zeroed EnrefEncoding with its type set to enref."""
    enc = EnrefEncoding(self.new_enref_array(), self.enref_encoding_layout)
    enc.type.set_enref()
    return enc

  def as_token_encoding(self, array):
    """Wraps `array` (not copied) in a TokenEncoding view."""
    return TokenEncoding(array, self.token_encoding_layout)

  def new_token_array(self):
    """Returns a zero-filled float64 vector of the token encoding length."""
    return np.zeros(self.token_encoding_length)

  def new_token_encoding(self, token, signals,
                         wordvec, bertvec):
    """Returns a new TokenEncoding with type token and populated sections."""
    enc = TokenEncoding(self.new_token_array(), self.token_encoding_layout)
    enc.type.set_token()
    enc.populate(token, signals, wordvec, bertvec)
    return enc

  def as_prediction_encoding(self, array):
    """Wraps `array` (not copied) in a PredictionEncoding view."""
    return PredictionEncoding(array, self.prediction_encoding_layout)

  def new_prediction_array(self):
    """Returns a zero-filled float64 vector of the prediction length."""
    return np.zeros(self.prediction_encoding_length)

  def new_prediction_encoding(self):
    """Returns a zeroed PredictionEncoding."""
    enc = PredictionEncoding(self.new_prediction_array(),
                             self.prediction_encoding_layout)
    return enc

  def build_enref_from_prediction(
      self, token,
      prediction):
    """Build new enref from prediction logits.

    Copies the token's array, marks the copy as an enref and fills the enref
    sections from the prediction, thresholded at 0.0 (> 0.0 becomes 1.0,
    otherwise 0.0).

    Args:
      token: token encoding whose `.array` is copied (embeddings carry over).
      prediction: prediction encoding for the same position.

    Returns:
      A new EnrefEncoding, or None if the prediction is not an enref.
    """
    if prediction.enref_meta.is_enref() <= 0.0:
      return None

    def binarize(values):
      return np.where(values > 0.0, 1.0, 0.0)

    # np.array copies, so the token's own array is left untouched.
    new_array = np.array(token.array)
    enref = self.as_enref_encoding(new_array)
    enref.type.set_enref()

    enref.enref_meta.replace(binarize(prediction.enref_meta.slice()))
    enref.enref_id.set(prediction.enref_id.get())
    enref.enref_properties.replace(
        binarize(prediction.enref_properties.slice()))
    if prediction.enref_properties.is_group() > 0.0:
      enref.enref_membership.replace(
          binarize(prediction.enref_membership.slice()))
    else:
      enref.enref_membership.set([])
    # Signals copied from the token do not apply to the enref.
    enref.signals.set([])

    return enref

  @staticmethod
  def _find_enref_spans(predictions):
    """Returns [start, end) spans of consecutive enref predictions.

    A span ends where a position is not an enref or predicts a different
    entity id than the span's first position.
    """
    spans = []
    current_span = None
    for i, pred_enc in enumerate(predictions):
      if current_span and (pred_enc.enref_meta.is_enref() <= 0.0 or
                           current_span[1] != pred_enc.enref_id.get()):
        spans.append((current_span[0], i))
        current_span = None
      if not current_span and pred_enc.enref_meta.is_enref() > 0.0:
        current_span = (i, pred_enc.enref_id.get())
    if current_span:
      spans.append((current_span[0], len(predictions)))
    return spans

  def build_enrefs_from_predictions(
      self, tokens,
      predictions,
      words,
      prev_enrefs):
    """Build new enrefs from prediction logits.

    Args:
      tokens: per-position token encodings (with `wordvec` and `bert`).
      predictions: per-position prediction encodings aligned with `tokens`.
      words: per-position words aligned with `tokens`.
      prev_enrefs: earlier enrefs; a non-new enref takes its name from the
        first of these with the same entity id.

    Returns:
      One EnrefEncoding per span, whose wordvec and bert sections hold the
      mean over the span's tokens.
    """
    enrefs = []
    for (start, end) in self._find_enref_spans(predictions):
      span = range(start, end)
      enref = self.build_enref_from_prediction(tokens[start],
                                               predictions[start])
      enref.wordvec.set(np.mean([tokens[i].wordvec.get() for i in span], 0))
      enref.bert.set(np.mean([tokens[i].bert.get() for i in span], 0))
      span_text = ' '.join([words[i] for i in span])

      name = words[start]
      if enref.enref_meta.is_new() <= 0.0:
        # A continued reference reuses the name of the earlier enref with the
        # same id; otherwise the span's first word is kept.
        enref_id = enref.enref_id.get()
        name = next((e.entity_name for e in prev_enrefs
                     if e.enref_id.get() == enref_id), name)
      enref.populate(name, (start, end), span_text)
      enrefs.append(enref)

    return enrefs

  def save(self, path):
    """Saves encoding to json file.

    Writes '<path>/encodings.json' (`path` is a directory). Each layout is
    stored as a list of [attribute_name, class_name] pairs that
    `load_from_json` can read back when given the file path.
    """
    encodings_dict = {
        'enref_encoding_layout': [
            (n, c.__name__) for (n, c) in self.enref_encoding_layout
        ],
        'token_encoding_layout': [
            (n, c.__name__) for (n, c) in self.token_encoding_layout
        ],
        'prediction_encoding_layout': [
            (n, c.__name__) for (n, c) in self.prediction_encoding_layout
        ],
    }

    filepath = os.path.join(path, 'encodings.json')
    with tf.io.gfile.GFile(filepath, 'w') as file:
      json.dump(encodings_dict, file, indent=2)
557

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.