google-research
564 lines · 20.1 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Data representations."""
17
18from __future__ import annotations19
20import collections21import copy22import dataclasses23import json24from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union25
26from dense_representations_for_entity_retrieval.mel.mewsli_x import io_util27
28JsonValue = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]29JsonDict = Dict[str, JsonValue]30JsonList = List[JsonValue]31StrOrPurePath = io_util.StrOrPurePath32
33
def to_jsonl(obj: JsonDict) -> str:
  """Serializes a JSON-compatible dict into a single JSONL-ready line.

  Non-ASCII characters are emitted verbatim (not \\uXXXX-escaped), keeping
  the output human-readable for multilingual text.
  """
  return json.dumps(obj, ensure_ascii=False)
37
@dataclasses.dataclass(frozen=True)
class Span:
  """A [start:end]-span in some external string."""
  start: int
  end: int

  def __post_init__(self):
    # Enforce 0 <= start < end at construction time; empty spans are invalid.
    if self.start < 0:
      raise ValueError(f"start offset is out of bounds {self}")
    if self.end < 0:
      raise ValueError(f"end offset is out of bounds {self}")
    if self.start >= self.end:
      raise ValueError(f"start and end offsets are non-monotonic {self}")

  @staticmethod
  def from_json(json_dict: JsonDict) -> Span:
    """Creates a new Span instance from the given JSON-dictionary."""
    return Span(json_dict["start"], json_dict["end"])

  def to_json(self) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return {"start": self.start, "end": self.end}

  def validate_offsets_relative_to_context(self, context: str) -> None:
    """Validates the span's offsets relative to a context string."""
    if self.start >= len(context):
      raise ValueError(
          f"start offset in {self} is out of bounds w.r.t. '{context}'")
    if self.end > len(context):
      raise ValueError(
          f"end offset in {self} is out of bounds w.r.t. '{context}'")

  def locate_in(self, spans: Iterable[Span]) -> Optional[int]:
    """Returns the index of the first span that fully contains `self`.

    Args:
      spans: The spans to search.

    Returns:
      First i such that spans[i].{start,end} covers `self.{start,end}`, or
      None if there is no such span, indicating that `self` either is out of
      range relative to spans or crosses span boundaries.
    """
    for index, candidate in enumerate(spans):
      # Coinciding starts and coinciding ends still count as containment.
      if (candidate.start <= self.start < candidate.end and
          candidate.start < self.end <= candidate.end):
        return index
    return None
88
@dataclasses.dataclass(frozen=True)
class TextSpan(Span):
  """A text span relative to an external string T, with text=T[start:end]."""
  text: str

  def validate_relative_to_context(self, context: str) -> None:
    """Checks that `self.text` equals the designated substring of `context`."""
    self.validate_offsets_relative_to_context(context)
    expected = context[self.start:self.end]
    if self.text != expected:
      raise ValueError(f"{self} does not match against context '{context}': "
                       f"'{self.text}' != '{expected}'")

  @staticmethod
  def from_context(span: Span, context: str) -> TextSpan:
    """Creates a new TextSpan by extracting the given `span` from `context`."""
    span.validate_offsets_relative_to_context(context)
    return TextSpan(span.start, span.end, text=context[span.start:span.end])

  @staticmethod
  def from_elements(start: int, end: int, context: str) -> TextSpan:
    """Creates a new TextSpan by extracting [start:end] from `context`."""
    return TextSpan.from_context(span=Span(start, end), context=context)

  @staticmethod
  def from_json(json_dict: JsonDict) -> TextSpan:
    """Creates a new TextSpan from the given JSON-dictionary."""
    return TextSpan(json_dict["start"], json_dict["end"], json_dict["text"])

  def to_json(self) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return {"start": self.start, "end": self.end, "text": self.text}
123
@dataclasses.dataclass(frozen=True)
class Entity:
  """An entity and its textual representation.

  Attributes:
    entity_id: Unique identifier of the entity, e.g. WikiData QID.
    title: A title phrase that names the entity.
    description: A definitional description of the entity that serves as its
      unique textual representation, e.g. taken from the beginning of the
      entity's Wikipedia page.
    sentence_spans: Sentence break annotations for the description, as
      character-level Span objects that index into `description`.
    sentences: Sentences extracted from `description` according to
      `sentence_spans`, as TextSpan objects that carry the actual sentence
      text for convenience. E.g., the string of the description's first
      sentence is `sentences[0].text`.
    description_language: Primary language code of the description and title,
      matching the Wikipedia language edition from which they were extracted.
    description_url: URL of the page where the description was extracted from.
  """
  entity_id: str
  title: str
  description: str
  sentence_spans: Tuple[Span, ...]
  description_language: str
  description_url: str

  def __post_init__(self):
    self.validate()

  @property
  def sentences(self) -> Iterator[TextSpan]:
    # Lazily materialize each sentence's text from the description.
    for sentence in self.sentence_spans:
      yield TextSpan.from_context(sentence, self.description)

  def validate(self):
    """Checks that every sentence span fits inside the description."""
    for sentence in self.sentence_spans:
      sentence.validate_offsets_relative_to_context(self.description)

  @staticmethod
  def from_json(json_dict: JsonDict) -> Entity:
    """Creates a new Entity from the given JSON-dictionary."""
    spans = tuple(Span.from_json(t) for t in json_dict["sentence_spans"])
    return Entity(
        entity_id=json_dict["entity_id"],
        title=json_dict["title"],
        description=json_dict["description"],
        description_language=json_dict["description_language"],
        description_url=json_dict["description_url"],
        sentence_spans=spans,
    )

  def to_json(self) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return {
        "entity_id": self.entity_id,
        "title": self.title,
        "description": self.description,
        "description_language": self.description_language,
        "description_url": self.description_url,
        "sentence_spans": [t.to_json() for t in self.sentence_spans],
    }
187
@dataclasses.dataclass(frozen=True)
class Mention:
  """A single mention of an entity, referring to some external context.

  Attributes:
    example_id: Unique identifier for the mention instance.
    mention_span: A TextSpan denoting one mention, relative to external
      context.
    entity_id: ID of the mentioned entity.
    metadata: Optional dictionary of additional information about the
      instance.
  """
  example_id: str
  mention_span: TextSpan
  entity_id: str
  metadata: Optional[Dict[str, str]] = None

  @staticmethod
  def from_json(json_dict: JsonDict) -> Mention:
    """Creates a new Mention from the given JSON-dictionary."""
    return Mention(
        example_id=json_dict["example_id"],
        mention_span=TextSpan.from_json(json_dict["mention_span"]),
        entity_id=json_dict["entity_id"],
        metadata=json_dict.get("metadata"),  # absent key -> None
    )

  def to_json(self) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    result = {
        "example_id": self.example_id,
        "mention_span": self.mention_span.to_json(),
        "entity_id": self.entity_id,
    }
    # Omit the key entirely when there is no metadata, matching from_json.
    if self.metadata is not None:
      result["metadata"] = self.metadata
    return result
224
@dataclasses.dataclass()
class Context:
  """A document text fragment and metadata.

  Attributes:
    document_title: Title of the document.
    document_url: URL of the document.
    document_id: An identifier for the document. For a Wikipedia page, this
      may be the associated WikiData QID.
    language: Primary language code of the document.
    text: Original text from the document.
    sentence_spans: Sentence break annotations for the text, as
      character-level Span objects that index into `text`.
    sentences: Sentences extracted from `text` according to `sentence_spans`.
      These TextSpan objects include the actual sentence text for added
      convenience. E.g., the first sentence's string is `sentences[0].text`.
    section_title: Optional title of the section under which `text` appeared.
  """
  document_title: str
  document_url: str
  document_id: str
  language: str
  text: str
  sentence_spans: Tuple[Span, ...]
  section_title: Optional[str] = None

  def __post_init__(self):
    self.validate()

  @property
  def sentences(self) -> Iterator[TextSpan]:
    # Lazily materialize each sentence's text from `text`.
    for span in self.sentence_spans:
      yield TextSpan.from_context(span, self.text)

  def validate(self):
    """Checks that every sentence span fits inside `text`."""
    for sentence_span in self.sentence_spans:
      sentence_span.validate_offsets_relative_to_context(self.text)

  @staticmethod
  def from_json(json_dict: JsonDict) -> Context:
    """Creates a new Context from the given JSON-dictionary."""
    return Context(
        document_title=json_dict["document_title"],
        section_title=json_dict.get("section_title"),  # absent -> None
        document_url=json_dict["document_url"],
        document_id=json_dict["document_id"],
        language=json_dict["language"],
        text=json_dict["text"],
        sentence_spans=tuple(
            Span.from_json(t) for t in json_dict["sentence_spans"]),
    )

  def to_json(self, keep_text: bool = True) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary.

    Args:
      keep_text: If False, `text` is emitted as an empty string, e.g. to
        reduce output size when only the annotations matter.
    """
    json_dict = dict(
        document_title=self.document_title,
        document_url=self.document_url,
        document_id=self.document_id,
        language=self.language,
        text=self.text if keep_text else "",
        sentence_spans=[t.to_json() for t in self.sentence_spans],
    )
    # Omit the key entirely when there is no section title, matching
    # from_json's use of dict.get.
    if self.section_title is not None:
      json_dict["section_title"] = self.section_title
    return json_dict

  def truncate(self, focus: int, window_size: int) -> Tuple[int, Context]:
    """Truncates the Context to window_size sentences each side of focus.

    This seeks to truncate the text and sentence_spans of `self` to
    self.sentence_spans[focus - window_size:focus + window_size + 1].

    When there are fewer than window_size sentences available before (after)
    the focus, this attempts to retain additional context sentences after
    (before) the focus.

    Args:
      focus: The index of the focus sentence in self.sentence_spans.
      window_size: Number of sentences to retain on each side of the focus.

    Returns:
      - c, the number of characters removed from the start of the text, which
        is useful for updating any Mention defined in relation to this
        Context.
      - new_context, a copy of the Context that is updated to contain the
        truncated text and sentence_spans.

    Raises:
      IndexError: if focus is not within the range of self.sentence_spans.
      ValueError: if window_size is negative.
    """
    if focus < 0 or focus >= len(self.sentence_spans):
      raise IndexError(f"Index {focus} invalid for {self.sentence_spans}")
    if window_size < 0:
      # Fixed error message: window_size == 0 is valid (keep only the focus
      # sentence); only negative values are rejected, so say "non-negative".
      raise ValueError(f"Expected a non-negative window, but got {window_size}")

    snt_window = self._get_sentence_window(focus, window_size)
    relevant_sentences = self.sentence_spans[snt_window.start:snt_window.end]

    # Shift all character offsets so the retained text starts at 0.
    char_offset = relevant_sentences[0].start
    char_end = relevant_sentences[-1].end
    new_text = self.text[char_offset:char_end]

    new_sentences = [
        Span(old_sentence.start - char_offset, old_sentence.end - char_offset)
        for old_sentence in relevant_sentences
    ]
    new_context = dataclasses.replace(
        self, text=new_text, sentence_spans=tuple(new_sentences))
    return char_offset, new_context

  def _get_sentence_window(self, focus: int, window_size: int) -> Span:
    """Gets Span of sentence indices to cover window around the focus index."""
    # Add window to the left of focus. If there are fewer sentences before
    # the focus sentence, carry over the remainder.
    left_index = max(focus - window_size, 0)
    remainder_left = window_size - (focus - left_index)
    assert remainder_left >= 0, remainder_left

    # Add window to the right of focus, including carryover. (Note,
    # right_index is an inclusive index.) If there are fewer sentences after
    # the focus sentence, carry back the remainder.
    right_index = min(focus + window_size + remainder_left,
                      len(self.sentence_spans) - 1)
    remainder_right = window_size - (right_index - focus)

    if remainder_right > 0:
      # Extend further leftward.
      left_index = max(left_index - remainder_right, 0)

    return Span(left_index, right_index + 1)
356
@dataclasses.dataclass()
class ContextualMentions:
  """Multiple entity mentions in a shared context."""
  context: Context
  mentions: List[Mention]

  def __post_init__(self):
    self.validate()

  def validate(self):
    """Checks the context and every mention span against the context text."""
    self.context.validate()
    for m in self.mentions:
      m.mention_span.validate_relative_to_context(self.context.text)

  @staticmethod
  def from_json(json_dict: JsonDict) -> ContextualMentions:
    """Creates a new ContextualMentions from the given JSON-dictionary."""
    return ContextualMentions(
        Context.from_json(json_dict["context"]),
        [Mention.from_json(m) for m in json_dict["mentions"]],
    )

  def to_json(self, keep_text: bool = True) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return dict(
        context=self.context.to_json(keep_text=keep_text),
        mentions=[m.to_json() for m in self.mentions],
    )

  def unnest_to_single_mention_per_context(self) -> Iterator[ContextualMention]:
    """Yields one ContextualMention per mention, deep-copying the context."""
    for m in self.mentions:
      yield ContextualMention(
          context=copy.deepcopy(self.context), mention=copy.deepcopy(m))

  @staticmethod
  def nest_mentions_by_shared_context(
      contextual_mentions: Iterable[ContextualMention]
  ) -> Iterator[ContextualMentions]:
    """Inverse of unnest_to_single_mention_per_context."""
    # Group mentions under one representative Context, keyed by the fields
    # that identify a unique context fragment.
    context_by_key = {}
    mentions_by_key = collections.defaultdict(list)
    for cm in contextual_mentions:
      ctx = cm.context
      key = (ctx.document_id, ctx.section_title, ctx.text)
      if key in context_by_key:
        assert context_by_key[key] == ctx, key
      else:
        context_by_key[key] = ctx
      mentions_by_key[key].append(cm.mention)

    for key, grouped_mentions in mentions_by_key.items():
      yield ContextualMentions(context_by_key[key], grouped_mentions)
411
@dataclasses.dataclass()
class ContextualMention:
  """A single entity mention in context."""
  context: Context
  mention: Mention

  def __post_init__(self):
    self.validate()

  def validate(self):
    """Checks the context and the mention span against the context text."""
    self.context.validate()
    self.mention.mention_span.validate_relative_to_context(self.context.text)

  @staticmethod
  def from_json(json_dict: JsonDict) -> ContextualMention:
    """Creates a new ContextualMention from the given JSON-dictionary."""
    return ContextualMention(
        Context.from_json(json_dict["context"]),
        Mention.from_json(json_dict["mention"]),
    )

  def to_json(self, keep_text: bool = True) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return dict(
        context=self.context.to_json(keep_text=keep_text),
        mention=self.mention.to_json(),
    )

  def truncate(self, window_size: int) -> Optional[ContextualMention]:
    """Truncates the context to window_size sentences each side of the mention.

    Args:
      window_size: Number of sentences to retain on each side of the sentence
        containing the mention. See Context.truncate for more detail.

    Returns:
      Returns None if no sentence spans were present or if the mention crosses
      sentence boundaries. Otherwise, returns an updated copy of the
      ContextualMention where `.context` contains the truncated text and
      sentences, and the character offsets in `.mention` updated accordingly.
    """
    focus_index = self.mention.mention_span.locate_in(
        self.context.sentence_spans)
    if focus_index is None:
      # No sentences at all, or the mention straddles a sentence boundary.
      return None

    chars_removed, trimmed_context = self.context.truncate(
        focus=focus_index, window_size=window_size)

    # Internal consistency check.
    limit = window_size * 2 + 1
    assert len(trimmed_context.sentence_spans) <= limit, (
        f"Got {len(trimmed_context.sentence_spans)}>{limit} sentences for "
        f"window_size={window_size} in truncated Context: {trimmed_context}")

    # Shift the mention's offsets by the amount of text cut from the front;
    # the mention text itself is unchanged.
    shifted_span = TextSpan(
        start=self.mention.mention_span.start - chars_removed,
        end=self.mention.mention_span.end - chars_removed,
        text=self.mention.mention_span.text)
    return ContextualMention(
        context=trimmed_context,
        mention=dataclasses.replace(self.mention, mention_span=shifted_span))
476
@dataclasses.dataclass()
class MentionEntityPair:
  """A ContextualMention paired with the Entity it refers to."""
  contextual_mention: ContextualMention
  entity: Entity

  def __post_init__(self):
    self.validate()

  def validate(self):
    """Validates both halves of the pair."""
    self.contextual_mention.validate()
    self.entity.validate()

  @staticmethod
  def from_json(json_dict: JsonDict) -> MentionEntityPair:
    """Creates a new MentionEntityPair from the given JSON-dictionary."""
    return MentionEntityPair(
        ContextualMention.from_json(json_dict["contextual_mention"]),
        Entity.from_json(json_dict["entity"]),
    )

  def to_json(self) -> JsonDict:
    """Returns instance as JSON-compatible nested dictionary."""
    return dict(
        contextual_mention=self.contextual_mention.to_json(),
        entity=self.entity.to_json(),
    )
507
# Constrained type variable over the JSONL-serializable schema dataclasses;
# lets load_jsonl return the same concrete type it was asked to parse.
SchemaAnyT = TypeVar("SchemaAnyT", ContextualMention, ContextualMentions,
                     Entity, MentionEntityPair)
# Plain union counterpart of SchemaAnyT, for heterogeneous collections.
SchemaAny = Union[ContextualMention, ContextualMentions, Entity,
                  MentionEntityPair]

# Constrained type variable so add_sentence_spans returns its argument's type.
EntityOrContext = TypeVar("EntityOrContext", Entity, Context)
515
def add_sentence_spans(item: EntityOrContext,
                       sentence_spans: Iterable[Span]) -> EntityOrContext:
  """Returns a copy of item, adding the given sentence_spans.

  Raises:
    ValueError: if `item` already has sentence_spans populated.
  """
  if item.sentence_spans:
    raise ValueError(f"sentence_spans already populated: {item}")
  spans_as_tuple = tuple(sentence_spans)
  return dataclasses.replace(item, sentence_spans=spans_as_tuple)
523
def load_text(path: StrOrPurePath) -> str:
  """Reads and returns the entire contents of the text file at `path`."""
  with io_util.open_file(path, "rt") as handle:
    return handle.read()
529
def load_jsonl_as_dicts(path: StrOrPurePath) -> List[JsonDict]:
  """Returns dict-records from JSONL file (without parsing into dataclasses)."""
  with io_util.open_file(path) as input_file:
    return [json.loads(raw_line) for raw_line in input_file]
535
def load_jsonl(path: StrOrPurePath,
               schema_cls: Type[SchemaAnyT]) -> List[SchemaAnyT]:
  """Loads the designated type of schema dataclass items from a JSONL file.

  Args:
    path: File path to load. Each line in the file is a JSON-serialized
      object.
    schema_cls: The dataclass to parse into, e.g. `ContextualMention`,
      `Entity`, etc.

  Returns:
    A list of validated instances of `schema_cls`, one per input line.
  """
  # Each dataclass validates itself in from_json/__post_init__.
  return [
      schema_cls.from_json(json_dict)
      for json_dict in load_jsonl_as_dicts(path)
  ]
553
def write_jsonl(path: StrOrPurePath, items: Iterable[SchemaAny]) -> None:
  """Writes a list of any of the schema dataclass items to JSONL file.

  Args:
    path: Output file path that will store each item as a JSON-serialized
      line.
    items: Items to output. Instances of a schema dataclass, e.g.
      `ContextualMention`, `Entity`, etc.
  """
  with io_util.open_file(path, "wt") as output_file:
    # One JSON object per line, newline-terminated.
    output_file.writelines(
        to_jsonl(item.to_json()) + "\n" for item in items)