# Mirrored from the google-research repository (repository-hosting page
# header removed during extraction).
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Data representations."""
17

18
from __future__ import annotations
19

20
import collections
21
import copy
22
import dataclasses
23
import json
24
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union
25

26
from dense_representations_for_entity_retrieval.mel.mewsli_x import io_util
27

28
JsonValue = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
29
JsonDict = Dict[str, JsonValue]
30
JsonList = List[JsonValue]
31
StrOrPurePath = io_util.StrOrPurePath
32

33

34
def to_jsonl(obj: JsonDict) -> str:
  """Serializes a JSON-compatible dictionary into one JSONL-ready line."""
  # ensure_ascii=False keeps non-ASCII text (titles, descriptions) readable.
  serialized = json.dumps(obj, ensure_ascii=False)
  return serialized


@dataclasses.dataclass(frozen=True)
class Span:
  """A half-open [start:end) character span into some external string."""
  start: int
  end: int

  def __post_init__(self):
    # Reject out-of-range offsets and empty or inverted spans up front.
    if self.start < 0:
      raise ValueError(f"start offset is out of bounds {self}")
    if self.end < 0:
      raise ValueError(f"end offset is out of bounds {self}")
    if self.start >= self.end:
      raise ValueError(f"start and end offsets are non-monotonic {self}")

  @staticmethod
  def from_json(json_dict: JsonDict) -> Span:
    """Creates a new Span instance from the given JSON-dictionary."""
    start, end = json_dict["start"], json_dict["end"]
    return Span(start=start, end=end)

  def to_json(self) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return {"start": self.start, "end": self.end}

  def validate_offsets_relative_to_context(self, context: str) -> None:
    """Raises ValueError if the span does not fit inside `context`."""
    context_length = len(context)
    if self.start >= context_length:
      raise ValueError(
          f"start offset in {self} is out of bounds w.r.t. '{context}'")
    if self.end > context_length:
      raise ValueError(
          f"end offset in {self} is out of bounds w.r.t. '{context}'")

  def locate_in(self, spans: Iterable[Span]) -> Optional[int]:
    """Returns the index of the first span that fully contains `self`.

    Args:
      spans: The candidate spans to search, in order.

    Returns:
      The first index i such that spans[i] covers `self.{start,end}` (start and
      end may coincide with the container's), or None when no candidate
      contains `self` — i.e. `self` is out of range relative to `spans` or
      crosses a span boundary.
    """
    for index, candidate in enumerate(spans):
      starts_inside = candidate.start <= self.start < candidate.end
      ends_inside = candidate.start < self.end <= candidate.end
      if starts_inside and ends_inside:
        return index
    return None


@dataclasses.dataclass(frozen=True)
class TextSpan(Span):
  """A text span relative to an external string T, with text=T[start:end]."""
  text: str

  def validate_relative_to_context(self, context: str) -> None:
    """Raises ValueError unless `self.text` equals `context[start:end]`."""
    self.validate_offsets_relative_to_context(context)
    ref_text = context[self.start:self.end]
    if self.text != ref_text:
      raise ValueError(f"{self} does not match against context '{context}': "
                       f"'{self.text}' != '{ref_text}'")

  @staticmethod
  def from_context(span: Span, context: str) -> TextSpan:
    """Creates a new TextSpan by extracting the given `span` from `context`."""
    span.validate_offsets_relative_to_context(context)
    extracted = context[span.start:span.end]
    return TextSpan(span.start, span.end, text=extracted)

  @staticmethod
  def from_elements(start: int, end: int, context: str) -> TextSpan:
    """Creates a new TextSpan covering `context[start:end]`."""
    return TextSpan.from_context(Span(start, end), context)

  @staticmethod
  def from_json(json_dict: JsonDict) -> TextSpan:
    """Creates a new TextSpan from the given JSON-dictionary."""
    return TextSpan(
        start=json_dict["start"],
        end=json_dict["end"],
        text=json_dict["text"])

  def to_json(self) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return {"start": self.start, "end": self.end, "text": self.text}


@dataclasses.dataclass(frozen=True)
class Entity:
  """An entity plus the text that serves as its unique representation.

  Attributes:
    entity_id: Unique identifier of the entity, e.g. a WikiData QID.
    title: A title phrase naming the entity.
    description: Definitional description text for the entity, e.g. taken from
      the beginning of the entity's Wikipedia page.
    sentence_spans: Character-level sentence boundaries, as Span objects that
      index into `description`.
    sentences: Convenience view over `description` according to
      `sentence_spans`, yielding TextSpan objects that carry the sentence text;
      e.g. the first sentence's string is `sentences[0].text`.
    description_language: Primary language code of the description and title,
      matching the Wikipedia edition they were extracted from.
    description_url: URL of the page the description was extracted from.
  """
  entity_id: str
  title: str
  description: str
  sentence_spans: Tuple[Span, ...]
  description_language: str
  description_url: str

  def __post_init__(self):
    self.validate()

  @property
  def sentences(self) -> Iterator[TextSpan]:
    # Lazily extract each annotated sentence from the description.
    return (TextSpan.from_context(span, self.description)
            for span in self.sentence_spans)

  def validate(self):
    """Raises ValueError if any sentence span is out of bounds."""
    for span in self.sentence_spans:
      span.validate_offsets_relative_to_context(self.description)

  @staticmethod
  def from_json(json_dict: JsonDict) -> Entity:
    """Creates a new Entity from the given JSON-dictionary."""
    spans = tuple(Span.from_json(s) for s in json_dict["sentence_spans"])
    return Entity(
        entity_id=json_dict["entity_id"],
        title=json_dict["title"],
        description=json_dict["description"],
        description_language=json_dict["description_language"],
        description_url=json_dict["description_url"],
        sentence_spans=spans,
    )

  def to_json(self) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return {
        "entity_id": self.entity_id,
        "title": self.title,
        "description": self.description,
        "description_language": self.description_language,
        "description_url": self.description_url,
        "sentence_spans": [span.to_json() for span in self.sentence_spans],
    }


@dataclasses.dataclass(frozen=True)
class Mention:
  """One mention of an entity, defined relative to some external context.

  Attributes:
    example_id: Unique identifier for this mention instance.
    mention_span: TextSpan of the mention, relative to the external context.
    entity_id: Identifier of the mentioned entity.
    metadata: Optional dictionary of extra information about the instance.
  """
  example_id: str
  mention_span: TextSpan
  entity_id: str
  metadata: Optional[Dict[str, str]] = None

  @staticmethod
  def from_json(json_dict: JsonDict) -> Mention:
    """Creates a new Mention from the given JSON-dictionary."""
    return Mention(
        example_id=json_dict["example_id"],
        mention_span=TextSpan.from_json(json_dict["mention_span"]),
        entity_id=json_dict["entity_id"],
        # `metadata` is optional in the serialized form; default to None.
        metadata=json_dict.get("metadata"),
    )

  def to_json(self) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    result = {
        "example_id": self.example_id,
        "mention_span": self.mention_span.to_json(),
        "entity_id": self.entity_id,
    }
    # Only serialize metadata when present, mirroring from_json.
    if self.metadata is not None:
      result["metadata"] = self.metadata
    return result


@dataclasses.dataclass()
class Context:
  """A document text fragment and metadata.

  Attributes:
    document_title: Title of the document.
    document_url: URL of the document.
    document_id: An identifier for the document. For a Wikipedia page, this may
      be the associated WikiData QID.
    language: Primary language code of the document.
    text: Original text from the document.
    sentence_spans: Sentence break annotations for the text, as character-level
      Span objects that index into `text`.
    sentences: Sentences extracted from `text` according to `sentence_spans`.
      These TextSpan objects include the actual sentence text for added
      convenience. E.g., the first sentence's string is `sentences[0].text`.
    section_title: Optional title of the section under which `text` appeared.
  """
  document_title: str
  document_url: str
  document_id: str
  language: str
  text: str
  sentence_spans: Tuple[Span, ...]
  section_title: Optional[str] = None

  def __post_init__(self):
    self.validate()

  @property
  def sentences(self) -> Iterator[TextSpan]:
    for span in self.sentence_spans:
      yield TextSpan.from_context(span, self.text)

  def validate(self):
    """Raises ValueError if any sentence span is out of bounds w.r.t. `text`."""
    for sentence_span in self.sentence_spans:
      sentence_span.validate_offsets_relative_to_context(self.text)

  @staticmethod
  def from_json(json_dict: JsonDict) -> Context:
    """Creates a new Context from the given JSON-dictionary."""
    return Context(
        document_title=json_dict["document_title"],
        section_title=json_dict.get("section_title"),
        document_url=json_dict["document_url"],
        document_id=json_dict["document_id"],
        language=json_dict["language"],
        text=json_dict["text"],
        sentence_spans=tuple(
            Span.from_json(t) for t in json_dict["sentence_spans"]),
    )

  def to_json(self, keep_text: bool = True) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary.

    Args:
      keep_text: If False, `text` is emitted as the empty string, e.g. to
        reduce output size when only the annotations are needed.
    """
    json_dict = dict(
        document_title=self.document_title,
        document_url=self.document_url,
        document_id=self.document_id,
        language=self.language,
        text=self.text if keep_text else "",
        sentence_spans=[t.to_json() for t in self.sentence_spans],
    )
    # Mirror from_json: section_title is only present when set.
    if self.section_title is not None:
      json_dict["section_title"] = self.section_title
    return json_dict

  def truncate(self, focus: int, window_size: int) -> Tuple[int, Context]:
    """Truncates the Context to window_size sentences each side of focus.

    This seeks to truncate the text and sentence_spans of `self` to
      self.sentence_spans[focus - window_size:focus + window_size + 1].

    When there are fewer than window_size sentences available before (after) the
    focus, this attempts to retain additional context sentences after (before)
    the focus.

    Args:
      focus: The index of the focus sentence in self.sentence_spans.
      window_size: Number of sentences to retain on each side of the focus.
        May be zero, in which case only the focus sentence is retained.

    Returns:
      - c, the number of characters removed from the start of the text, which is
        useful for updating any Mention defined in relation to this Context.
      - new_context, a copy of the Context that is updated to contain the
        truncated text and sentence_spans.

    Raises:
      IndexError: if focus is not within the range of self.sentence_spans.
      ValueError: if window_size is negative.
    """
    if focus < 0 or focus >= len(self.sentence_spans):
      raise IndexError(f"Index {focus} invalid for {self.sentence_spans}")
    if window_size < 0:
      # Bug fix: window_size == 0 is accepted (focus sentence only), so the
      # message says "non-negative" rather than the previous "positive".
      raise ValueError(f"Expected a non-negative window, but got {window_size}")

    snt_window = self._get_sentence_window(focus, window_size)
    relevant_sentences = self.sentence_spans[snt_window.start:snt_window.end]

    # Cut the text down to the retained sentence range and shift all sentence
    # offsets left by the number of characters dropped from the front.
    char_offset = relevant_sentences[0].start
    char_end = relevant_sentences[-1].end
    new_text = self.text[char_offset:char_end]

    new_sentences = [
        Span(old_sentence.start - char_offset, old_sentence.end - char_offset)
        for old_sentence in relevant_sentences
    ]
    new_context = dataclasses.replace(
        self, text=new_text, sentence_spans=tuple(new_sentences))
    return char_offset, new_context

  def _get_sentence_window(self, focus: int, window_size: int) -> Span:
    """Gets Span of sentence indices to cover window around the focus index."""
    # Add window to the left of focus. If there are fewer sentences before the
    # focus sentence, carry over the remainder.
    left_index = max(focus - window_size, 0)
    remainder_left = window_size - (focus - left_index)
    assert remainder_left >= 0, remainder_left

    # Add window to the right of focus, including carryover. (Note, right_index
    # is an inclusive index.) If there are fewer sentences after the focus
    # sentence, carry back the remainder.
    right_index = min(focus + window_size + remainder_left,
                      len(self.sentence_spans) - 1)
    remainder_right = window_size - (right_index - focus)

    if remainder_right > 0:
      # Extend further leftward.
      left_index = max(left_index - remainder_right, 0)

    return Span(left_index, right_index + 1)


@dataclasses.dataclass()
class ContextualMentions:
  """Several entity mentions that all refer to one shared Context."""
  context: Context
  mentions: List[Mention]

  def __post_init__(self):
    self.validate()

  def validate(self):
    """Validates the context and every mention span against its text."""
    self.context.validate()
    for mention in self.mentions:
      mention.mention_span.validate_relative_to_context(self.context.text)

  @staticmethod
  def from_json(json_dict: JsonDict) -> ContextualMentions:
    """Creates a new ContextualMentions from the given JSON-dictionary."""
    parsed_context = Context.from_json(json_dict["context"])
    parsed_mentions = [Mention.from_json(m) for m in json_dict["mentions"]]
    return ContextualMentions(context=parsed_context, mentions=parsed_mentions)

  def to_json(self, keep_text: bool = True) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return {
        "context": self.context.to_json(keep_text=keep_text),
        "mentions": [m.to_json() for m in self.mentions],
    }

  def unnest_to_single_mention_per_context(self) -> Iterator[ContextualMention]:
    """Yields one ContextualMention per mention, each with its own deep copy."""
    for mention in self.mentions:
      yield ContextualMention(
          context=copy.deepcopy(self.context), mention=copy.deepcopy(mention))

  @staticmethod
  def nest_mentions_by_shared_context(
      contextual_mentions: Iterable[ContextualMention]
  ) -> Iterator[ContextualMentions]:
    """Inverse of unnest_to_single_mention_per_context."""
    # Group mentions by a key that identifies the context they share.
    contexts_by_key = {}
    mentions_by_key = collections.defaultdict(list)
    for cm in contextual_mentions:
      ctx = cm.context
      key = (ctx.document_id, ctx.section_title, ctx.text)
      if key in contexts_by_key:
        # Contexts mapping to the same key are expected to be identical.
        assert contexts_by_key[key] == ctx, key
      else:
        contexts_by_key[key] = ctx
      mentions_by_key[key].append(cm.mention)

    for key, mention_group in mentions_by_key.items():
      yield ContextualMentions(contexts_by_key[key], mention_group)


@dataclasses.dataclass()
class ContextualMention:
  """One entity mention together with the Context it occurs in."""
  context: Context
  mention: Mention

  def __post_init__(self):
    self.validate()

  def validate(self):
    """Validates the context and the mention span against its text."""
    self.context.validate()
    self.mention.mention_span.validate_relative_to_context(self.context.text)

  @staticmethod
  def from_json(json_dict: JsonDict) -> ContextualMention:
    """Creates a new ContextualMention from the given JSON-dictionary."""
    return ContextualMention(
        context=Context.from_json(json_dict["context"]),
        mention=Mention.from_json(json_dict["mention"]),
    )

  def to_json(self, keep_text: bool = True) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return {
        "context": self.context.to_json(keep_text=keep_text),
        "mention": self.mention.to_json(),
    }

  def truncate(self, window_size: int) -> Optional[ContextualMention]:
    """Truncates the context to window_size sentences each side of the mention.

    Args:
      window_size: Number of sentences to retain on each side of the sentence
        containing the mention. See Context.truncate for more detail.

    Returns:
      None if no sentence spans were present or if the mention crosses
      sentence boundaries. Otherwise, an updated copy of the ContextualMention
      where `.context` holds the truncated text and sentences, and the
      character offsets in `.mention` are shifted accordingly.
    """
    sentence_index = self.mention.mention_span.locate_in(
        self.context.sentence_spans)
    if sentence_index is None:
      # The context has no sentences or the mention crosses sentence boundaries.
      return None

    char_offset, truncated_context = self.context.truncate(
        focus=sentence_index, window_size=window_size)

    # Internal consistency check.
    max_valid = window_size * 2 + 1
    assert len(truncated_context.sentence_spans) <= max_valid, (
        f"Got {len(truncated_context.sentence_spans)}>{max_valid} sentences for "
        f"window_size={window_size} in truncated Context: {truncated_context}")

    # Shift the mention offsets left by the number of characters removed from
    # the front of the context; the surface text itself is unchanged.
    shifted_span = TextSpan(
        start=self.mention.mention_span.start - char_offset,
        end=self.mention.mention_span.end - char_offset,
        text=self.mention.mention_span.text)
    updated_mention = dataclasses.replace(self.mention,
                                          mention_span=shifted_span)
    return ContextualMention(context=truncated_context,
                             mention=updated_mention)


@dataclasses.dataclass()
class MentionEntityPair:
  """A ContextualMention together with the Entity it refers to."""
  contextual_mention: ContextualMention
  entity: Entity

  def __post_init__(self):
    self.validate()

  def validate(self):
    """Validates both members of the pair."""
    self.contextual_mention.validate()
    self.entity.validate()

  @staticmethod
  def from_json(json_dict: JsonDict) -> MentionEntityPair:
    """Creates a new MentionEntityPair from the given JSON-dictionary."""
    parsed_cm = ContextualMention.from_json(json_dict["contextual_mention"])
    parsed_entity = Entity.from_json(json_dict["entity"])
    return MentionEntityPair(contextual_mention=parsed_cm, entity=parsed_entity)

  def to_json(self) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return {
        "contextual_mention": self.contextual_mention.to_json(),
        "entity": self.entity.to_json(),
    }


# Bound TypeVar over the schema dataclasses in this module, for functions that
# consume and produce the SAME concrete type (e.g. load_jsonl).
SchemaAnyT = TypeVar("SchemaAnyT", ContextualMention, ContextualMentions,
                     Entity, MentionEntityPair)
# Union counterpart of SchemaAnyT, for heterogeneous collections of schema
# items (e.g. write_jsonl accepts a mix).
SchemaAny = Union[ContextualMention, ContextualMentions, Entity,
                  MentionEntityPair]

# TypeVar over the two dataclasses that carry a `sentence_spans` field,
# used by add_sentence_spans below.
EntityOrContext = TypeVar("EntityOrContext", Entity, Context)


def add_sentence_spans(item: EntityOrContext,
                       sentence_spans: Iterable[Span]) -> EntityOrContext:
  """Returns a copy of `item` with the given sentence_spans filled in.

  Args:
    item: An Entity or Context whose `sentence_spans` field is still empty.
    sentence_spans: The sentence annotations to attach.

  Raises:
    ValueError: if `item` already carries sentence spans.
  """
  if item.sentence_spans:
    raise ValueError(f"sentence_spans already populated: {item}")
  updated_item = dataclasses.replace(
      item, sentence_spans=tuple(sentence_spans))
  return updated_item


def load_text(path: StrOrPurePath) -> str:
  """Reads and returns the entire contents of the text file at `path`."""
  with io_util.open_file(path, "rt") as reader:
    return reader.read()


def load_jsonl_as_dicts(path: StrOrPurePath) -> List[JsonDict]:
  """Returns dict-records from JSONL file (without parsing into dataclasses)."""
  # One JSON object per line; parse each line independently.
  with io_util.open_file(path) as reader:
    return list(map(json.loads, reader))


def load_jsonl(path: StrOrPurePath,
               schema_cls: Type[SchemaAnyT]) -> List[SchemaAnyT]:
  """Loads the designated type of schema dataclass items from a JSONL file.

  Args:
    path: File path to load. Each line in the file is a JSON-serialized object.
    schema_cls: The dataclass to parse into, e.g. `ContextualMention`, `Entity`,
      etc.

  Returns:
    A list of validated instances of `schema_cls`, one per input line.
  """
  # Each from_json call also validates the record, via the dataclass's
  # __post_init__ where one is defined.
  return [schema_cls.from_json(record) for record in load_jsonl_as_dicts(path)]


def write_jsonl(path: StrOrPurePath, items: Iterable[SchemaAny]) -> None:
  """Writes a list of any of the schema dataclass items to JSONL file.

  Args:
    path: Output file path that will store each item as a JSON-serialized line.
    items: Items to output. Instances of a schema dataclass, e.g.
      `ContextualMention`, `Entity`, etc.
  """
  with io_util.open_file(path, "wt") as writer:
    for item in items:
      # One serialized item per line, newline-terminated.
      writer.write(to_jsonl(item.to_json()) + "\n")


# (Removed extraction artifact: Russian-language cookie-consent banner
# injected by the hosting mirror's web page; not part of the source module.)