dream
142 строки · 5.5 Кб
1import re2from typing import Dict, Sequence, Optional3
4import src.interactive.functions as interactive5
6import schemas7from config import settings8
# Matches everything that is stripped from a generated beam: any character
# outside the class ``a-z A-Z 0-9 - space`` and the standalone word "none"
# (matched case-insensitively).
POSTPROCESSING_REGEXP = re.compile(
    r"[^a-zA-Z0-9\- ]|\bnone\b",
    flags=re.IGNORECASE,
)
11
class COMeTBaseEngine:
    """Shared scaffolding for the COMeT inference engines.

    The constructor loads the model options, state dict, data loader, text
    encoder and sampler through ``src.interactive.functions`` and moves the
    model to ``settings.device``. Graph-specific behaviour is supplied by
    subclasses, which must override ``_calc_n_ctx``, ``process_request``,
    ``_get_result`` and ``annotator``.
    """

    def __init__(self, graph, model_path, decoding_algorithm):
        """Load every artefact required for generation.

        Args:
            graph: Graph identifier, forwarded to ``interactive.load_data``.
            model_path: Forwarded to ``interactive.load_model_file``.
            decoding_algorithm: Forwarded to ``interactive.set_sampler``.

        Raises:
            NotImplementedError: If the subclass does not override
                ``_calc_n_ctx``.
        """
        self.graph = graph
        self.model_path = model_path
        self.decoding_algorithm = decoding_algorithm

        self._opt, self._state_dict = interactive.load_model_file(self.model_path)
        self._data_loader, self._text_encoder = interactive.load_data(self.graph, self._opt)
        self._sampler = interactive.set_sampler(self._opt, self.decoding_algorithm, self._data_loader)
        # The context length depends on the graph, hence the subclass hook.
        self._n_ctx = self._calc_n_ctx()
        # NOTE(review): the vocabulary is extended by n_ctx entries, presumably
        # for positional tokens -- confirm against interactive.make_model.
        self._n_vocab = len(self._text_encoder.encoder) + self._n_ctx

        self._model = interactive.make_model(self._opt, self._n_vocab, self._n_ctx, self._state_dict)
        self._model.to(device=settings.device)

        # Schema references; subclasses overwrite the ones they support
        # after calling this constructor.
        self._input_event_model = None
        self._response_model = None
        self._annotator_input_model = None
        self._annotator_response_model = None

    @property
    def input_event_model(self) -> Optional[schemas.BaseModel]:
        """Schema registered for request payloads, or ``None`` if unset."""
        return self._input_event_model

    @property
    def response_model(self) -> Optional[schemas.BaseModel]:
        """Schema registered for responses, or ``None`` if unset."""
        return self._response_model

    @property
    def annotator_input_model(self) -> Optional[schemas.BaseModel]:
        """Schema registered for annotator requests, or ``None`` if unset."""
        return self._annotator_input_model

    @property
    def annotator_response_model(self) -> Optional[schemas.BaseModel]:
        """Schema registered for annotator responses, or ``None`` if unset."""
        return self._annotator_response_model

    @staticmethod
    def beams_cleanup(preprocessed_beams):
        """Return the beams with unwanted text removed and empty ones dropped.

        Every match of ``POSTPROCESSING_REGEXP`` is deleted from each beam,
        the remainder is stripped of surrounding whitespace, and beams that
        end up empty are discarded. Order is preserved.
        """
        stripped_beams = (
            POSTPROCESSING_REGEXP.sub("", beam).strip() for beam in preprocessed_beams
        )
        return [beam for beam in stripped_beams if beam]

    def all_beams_cleanup(self, raw_result, include_beams_key=True):
        """Clean the ``"beams"`` list of every entry of ``raw_result`` in place.

        Args:
            raw_result: Mapping from relation/category to a mapping that may
                hold a ``"beams"`` list (a missing key is treated as empty).
            include_beams_key: When true, the cleaned list is stored back
                under ``"beams"`` and the rest of the entry is kept. When
                false, the whole entry is replaced by the cleaned list.

        Returns:
            The same ``raw_result`` object, after modification.
        """
        # Only values of existing keys are reassigned, so mutating while
        # iterating over the keys is safe.
        for relation_or_category in raw_result:
            entry = raw_result[relation_or_category]
            cleaned_beams = self.beams_cleanup(entry.get("beams", []))
            if include_beams_key:
                entry["beams"] = cleaned_beams
            else:
                raw_result[relation_or_category] = cleaned_beams
        return raw_result

    def _calc_n_ctx(self) -> int:
        """Return the context length for the graph; must be overridden."""
        # Failing loudly here avoids an opaque ``int + None`` TypeError in
        # ``__init__`` when a subclass forgets this hook.
        raise NotImplementedError("Subclasses must implement _calc_n_ctx()")

    def process_request(self, *args, **kwargs):
        """Handle one request payload; must be overridden."""
        raise NotImplementedError("Subclasses must implement process_request()")

    def _get_result(self, *args, **kwargs):
        """Run generation and post-processing; must be overridden."""
        raise NotImplementedError("Subclasses must implement _get_result()")

    def annotator(self, *args, **kwargs):
        """Annotate a batch of inputs; must be overridden."""
        raise NotImplementedError("Subclasses must implement annotator()")
78
class COMeTAtomic(COMeTBaseEngine):
    """COMeT engine bound to the ``"atomic"`` graph.

    Registers the Atomic request/response schemas; no annotator is offered.
    """

    def __init__(self, model_path, decoding_algorithm):
        """Initialise the base engine for ``"atomic"`` and register its schemas."""
        super().__init__(
            graph="atomic",
            model_path=model_path,
            decoding_algorithm=decoding_algorithm,
        )
        self._input_event_model = schemas.AtomicInputEventModel
        self._response_model = schemas.AtomicResponseModel

    def _calc_n_ctx(self):
        """Return the data loader's ``max_event`` plus its ``max_effect``."""
        loader = self._data_loader
        return loader.max_event + loader.max_effect

    def process_request(self, input_event: schemas.AtomicInputEventModel) -> Dict:
        """Generate for the ``"input"`` and ``"category"`` items of ``input_event``."""
        event = input_event["input"]
        category = input_event["category"]
        return self._get_result(event, category)

    def _get_result(self, event: str, category: Sequence[str]) -> Dict:
        """Call ``interactive.get_atomic_sequence`` and clean the beams it returns."""
        return self.all_beams_cleanup(
            interactive.get_atomic_sequence(
                event,
                self._model,
                self._sampler,
                self._data_loader,
                self._text_encoder,
                category,
            )
        )

    def annotator(self, *args, **kwargs):
        """Always raise: the atomic graph has no annotator."""
        raise NotImplementedError("No annotator for atomic graph is available!")
100
class COMeTConceptNet(COMeTBaseEngine):
    """COMeT engine bound to the ``"conceptnet"`` graph.

    Registers the ConceptNet request/response schemas as well as the
    annotator request/response schemas.
    """

    def __init__(self, model_path, decoding_algorithm):
        """Initialise the base engine for ``"conceptnet"`` and register its schemas."""
        super().__init__(
            graph="conceptnet",
            model_path=model_path,
            decoding_algorithm=decoding_algorithm,
        )
        self._input_event_model = schemas.ConceptNetInputEventModel
        self._response_model = schemas.ConceptNetResponseModel
        self._annotator_input_model = schemas.ConceptNetAnnotatorEventModel
        self._annotator_response_model = schemas.ConceptNetAnnotatorResponseModel

    def _calc_n_ctx(self):
        """Return the sum of the data loader's ``max_e1``, ``max_e2`` and ``max_r``."""
        loader = self._data_loader
        return loader.max_e1 + loader.max_e2 + loader.max_r

    def process_request(self, input_event: schemas.ConceptNetInputEventModel) -> Dict:
        """Generate for the ``"input"`` and ``"category"`` items of ``input_event``."""
        event = input_event["input"]
        category = input_event["category"]
        return self._get_result(event, category)

    def _get_result(self, event, category):
        """Call ``interactive.get_conceptnet_sequence`` and clean the beams it returns."""
        return self.all_beams_cleanup(
            interactive.get_conceptnet_sequence(
                event,
                self._model,
                self._sampler,
                self._data_loader,
                self._text_encoder,
                category,
            )
        )

    def annotator(self, input_event: schemas.ConceptNetAnnotatorEventModel):
        """Generate results for every noun phrase of every group in the batch.

        ``input_event["nounphrases"]`` is iterated as a sequence of groups,
        each group being a sequence of noun phrases. The return value is a
        list with one dict per group, mapping each noun phrase to its result
        with every relation entry flattened to its list of cleaned beams.
        """
        # NOTE(review): ``_get_result`` has already cleaned the beams; the
        # second ``all_beams_cleanup`` call is what flattens each entry to a
        # plain list, and it re-runs the cleanup while doing so.
        return [
            {
                nounphrase: self.all_beams_cleanup(
                    self._get_result(nounphrase, input_event["category"]),
                    include_beams_key=False,
                )
                for nounphrase in group
            }
            for group in input_event["nounphrases"]
        ]
131
class COMeTFactory:
    """Callable that builds the COMeT engine matching a graph name."""

    def __init__(self, graph):
        """Remember which graph the produced engines should serve."""
        self.graph = graph

    def __call__(self, model_path, decoding_algorithm):
        """Instantiate the engine for ``self.graph``.

        Args:
            model_path: Forwarded to the engine constructor.
            decoding_algorithm: Forwarded to the engine constructor.

        Returns:
            A ``COMeTAtomic`` for ``"atomic"`` or a ``COMeTConceptNet`` for
            ``"conceptnet"``.

        Raises:
            ValueError: If ``self.graph`` is neither of the two known names.
        """
        engine_kwargs = {
            "model_path": model_path,
            "decoding_algorithm": decoding_algorithm,
        }
        if self.graph == "atomic":
            return COMeTAtomic(**engine_kwargs)
        if self.graph == "conceptnet":
            return COMeTConceptNet(**engine_kwargs)
        raise ValueError(f"Graph {self.graph} does not exist!")