dream

Форк
0
142 строки · 5.5 Кб
1
import re
2
from typing import Dict, List, Optional, Sequence, Type
3

4
import src.interactive.functions as interactive
5

6
import schemas
7
from config import settings
8

9
# Pattern of text stripped from generated beams: any character other than an
# ASCII letter, digit, hyphen or space, plus the standalone word "none"
# (matched case-insensitively).
POSTPROCESSING_REGEXP = re.compile(
    r"[^a-zA-Z0-9\- ]|\bnone\b",
    flags=re.I,
)
10

11

12
class COMeTBaseEngine:
    """Common loading and post-processing logic for COMeT inference engines.

    The constructor loads the model file, data loader, text encoder and sampler
    through ``src.interactive.functions``, builds the model and moves it to
    ``settings.device``. Subclasses implement the graph-specific hooks
    (``_calc_n_ctx``, ``process_request``, ``_get_result`` and optionally
    ``annotator``) and assign the schema classes exposed by the properties.
    """

    def __init__(self, graph, model_path, decoding_algorithm):
        """Load and prepare the model for ``graph``.

        Args:
            graph: Graph name forwarded to ``interactive.load_data``.
            model_path: Model file passed to ``interactive.load_model_file``.
            decoding_algorithm: Decoding strategy passed to ``interactive.set_sampler``.
        """
        self.graph = graph
        self.model_path = model_path
        self.decoding_algorithm = decoding_algorithm

        self._opt, self._state_dict = interactive.load_model_file(self.model_path)
        self._data_loader, self._text_encoder = interactive.load_data(self.graph, self._opt)
        self._sampler = interactive.set_sampler(self._opt, self.decoding_algorithm, self._data_loader)
        # Context size is graph-specific, so it comes from the subclass hook;
        # the subclass implementations read self._data_loader, hence this order.
        self._n_ctx = self._calc_n_ctx()
        # Vocabulary = text-encoder tokens + n_ctx extra ids.
        # NOTE(review): presumably the extra ids are position tokens; confirm in interactive.make_model.
        self._n_vocab = len(self._text_encoder.encoder) + self._n_ctx

        self._model = interactive.make_model(self._opt, self._n_vocab, self._n_ctx, self._state_dict)
        self._model.to(device=settings.device)

        # Request/response schema classes; subclasses overwrite the ones they support.
        self._input_event_model = None
        self._response_model = None
        self._annotator_input_model = None
        self._annotator_response_model = None

    @property
    def input_event_model(self) -> Optional[Type[schemas.BaseModel]]:
        """Schema class for ``process_request`` input, or None if not set."""
        return self._input_event_model

    @property
    def response_model(self) -> Optional[Type[schemas.BaseModel]]:
        """Schema class for ``process_request`` output, or None if not set."""
        return self._response_model

    @property
    def annotator_input_model(self) -> Optional[Type[schemas.BaseModel]]:
        """Schema class for ``annotator`` input, or None if the engine has no annotator."""
        return self._annotator_input_model

    @property
    def annotator_response_model(self) -> Optional[Type[schemas.BaseModel]]:
        """Schema class for ``annotator`` output, or None if the engine has no annotator."""
        return self._annotator_response_model

    @staticmethod
    def beams_cleanup(preprocessed_beams):
        """Clean generated beams and drop the ones that end up empty.

        Each beam has every ``POSTPROCESSING_REGEXP`` match removed and is then
        stripped of surrounding whitespace; empty results are discarded.

        Args:
            preprocessed_beams: Iterable of beam strings.

        Returns:
            List of cleaned, non-empty beam strings, in input order.
        """
        cleaned = (POSTPROCESSING_REGEXP.sub("", beam).strip() for beam in preprocessed_beams)
        return [beam for beam in cleaned if beam]

    def all_beams_cleanup(self, raw_result, include_beams_key=True):
        """Clean the beams of every entry of ``raw_result`` in place.

        ``raw_result`` maps a relation/category to a dict that may hold a
        ``"beams"`` list (a missing key is treated as an empty list).

        Args:
            raw_result: Mapping to clean; it is mutated and returned.
            include_beams_key: If True, the cleaned list replaces
                ``raw_result[key]["beams"]``. If False, the whole entry is
                replaced by the cleaned list, i.e. the result is flattened to
                ``{key: [beams...]}``.

        Returns:
            The same ``raw_result`` object.
        """
        for relation_or_category in raw_result:
            entry = raw_result[relation_or_category]
            cleaned = self.beams_cleanup(entry.get("beams", []))
            if include_beams_key:
                entry["beams"] = cleaned
            else:
                # Replacing the value of an existing key is safe while iterating.
                raw_result[relation_or_category] = cleaned
        return raw_result

    # Subclass hooks. They raise instead of silently returning None so a
    # missing override fails loudly (a None from _calc_n_ctx would otherwise
    # surface as an opaque TypeError in __init__).

    def _calc_n_ctx(self) -> int:
        """Return the model context length for this graph."""
        raise NotImplementedError(f"{type(self).__name__} must implement _calc_n_ctx()")

    def process_request(self, *args, **kwargs):
        """Handle one inference request; implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement process_request()")

    def _get_result(self, *args, **kwargs):
        """Run the graph-specific generation; implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement _get_result()")

    def annotator(self, *args, **kwargs):
        """Annotate a batch of inputs; only some engines support this."""
        raise NotImplementedError(f"{type(self).__name__} does not provide an annotator")
77

78

79
class COMeTAtomic(COMeTBaseEngine):
    """COMeT engine for the ATOMIC graph; it has no annotator."""

    def __init__(self, model_path, decoding_algorithm):
        super().__init__(
            graph="atomic",
            model_path=model_path,
            decoding_algorithm=decoding_algorithm,
        )
        # Schema classes surfaced through the base-class properties.
        self._input_event_model = schemas.AtomicInputEventModel
        self._response_model = schemas.AtomicResponseModel

    def _calc_n_ctx(self):
        # Context length is the data loader's max_event + max_effect.
        loader = self._data_loader
        return loader.max_event + loader.max_effect

    def process_request(self, input_event: schemas.AtomicInputEventModel) -> Dict:
        """Run ATOMIC inference for one request.

        ``input_event`` is read with ``["input"]`` (the event text) and
        ``["category"]``. NOTE(review): it is subscripted, so it is presumably
        a dict dumped from AtomicInputEventModel; confirm at the caller.
        """
        event_text = input_event["input"]
        categories = input_event["category"]
        return self._get_result(event_text, categories)

    def _get_result(self, event: str, category: Sequence[str]) -> Dict:
        """Generate ATOMIC beams for ``event`` and return them with beams cleaned."""
        generated = interactive.get_atomic_sequence(
            event,
            self._model,
            self._sampler,
            self._data_loader,
            self._text_encoder,
            category,
        )
        return self.all_beams_cleanup(generated)

    def annotator(self, *args, **kwargs):
        """Not supported for ATOMIC; always raises NotImplementedError."""
        raise NotImplementedError("No annotator for atomic graph is available!")
99

100

101
class COMeTConceptNet(COMeTBaseEngine):
    """COMeT engine for the ConceptNet graph, including a noun-phrase annotator."""

    def __init__(self, model_path, decoding_algorithm):
        super().__init__(graph="conceptnet", model_path=model_path, decoding_algorithm=decoding_algorithm)
        # Schema classes surfaced through the base-class properties.
        self._input_event_model = schemas.ConceptNetInputEventModel
        self._response_model = schemas.ConceptNetResponseModel
        self._annotator_input_model = schemas.ConceptNetAnnotatorEventModel
        self._annotator_response_model = schemas.ConceptNetAnnotatorResponseModel

    def _calc_n_ctx(self):
        # Context length is the data loader's max_e1 + max_e2 + max_r.
        return self._data_loader.max_e1 + self._data_loader.max_e2 + self._data_loader.max_r

    def process_request(self, input_event: schemas.ConceptNetInputEventModel) -> Dict:
        """Run ConceptNet inference for one request.

        ``input_event`` is read with ``["input"]`` and ``["category"]``.
        NOTE(review): it is subscripted, so it is presumably a dict dumped from
        ConceptNetInputEventModel; confirm at the caller.
        """
        return self._get_result(input_event["input"], input_event["category"])

    def _get_result(self, event: str, category: Sequence[str]) -> Dict:
        """Generate ConceptNet beams for ``event`` and return them with beams cleaned."""
        raw_result = interactive.get_conceptnet_sequence(
            event, self._model, self._sampler, self._data_loader, self._text_encoder, category
        )
        return self.all_beams_cleanup(raw_result)

    def annotator(self, input_event: schemas.ConceptNetAnnotatorEventModel) -> List[Dict]:
        """Run ConceptNet inference for every noun phrase of every input group.

        ``input_event["nounphrases"]`` is iterated as a sequence of groups of
        noun phrases; ``input_event["category"]`` is passed to each query.

        Returns:
            One dict per group, mapping each noun phrase to
            ``{relation: [cleaned beams]}``.
        """
        batch = []
        for nounphrases in input_event["nounphrases"]:
            result = {}
            for nounphrase in nounphrases:
                conceptnet_result = self._get_result(nounphrase, input_event["category"])
                # _get_result already cleaned the beams; this second pass mainly
                # flattens {relation: {"beams": [...]}} into {relation: [...]}.
                result[nounphrase] = self.all_beams_cleanup(conceptnet_result, include_beams_key=False)
            batch.append(result)
        return batch
130

131

132
class COMeTFactory:
    """Callable that builds the COMeT engine for a graph name.

    Supported names: "atomic" (COMeTAtomic) and "conceptnet" (COMeTConceptNet).
    """

    def __init__(self, graph):
        # Graph name whose engine __call__ will construct.
        self.graph = graph

    def __call__(self, model_path, decoding_algorithm):
        """Instantiate the engine for ``self.graph``.

        Raises:
            ValueError: if ``self.graph`` is neither "atomic" nor "conceptnet".
        """
        wants_atomic = self.graph == "atomic"
        if not wants_atomic and not self.graph == "conceptnet":
            raise ValueError(f"Graph {self.graph} does not exist!")
        engine_cls = COMeTAtomic if wants_atomic else COMeTConceptNet
        return engine_cls(model_path=model_path, decoding_algorithm=decoding_algorithm)
143

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.