dream

Форк
0
112 строк · 4.7 Кб
1
# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import sqlite3
16
from logging import getLogger
17

18
from deeppavlov.core.common.registry import register
19
from deeppavlov.core.models.component import Component
20
from deeppavlov.core.commands.utils import expand_path
21

22
logger = getLogger(__name__)
23

24

25
@register("wiki_sqlite_vocab")
26
class WikiSQLiteVocab(Component):
27
    def __init__(self, load_path: str, shuffle: bool = False, top_n: int = 2, **kwargs) -> None:
28
        load_path = str(expand_path(load_path))
29
        self.top_n = top_n
30
        logger.info("Connecting to database, path: {}".format(load_path))
31
        try:
32
            self.connect = sqlite3.connect(load_path, check_same_thread=False)
33
        except sqlite3.OperationalError as e:
34
            e.args = e.args + ("Check that DB path exists and is a valid DB file",)
35
            raise e
36
        try:
37
            self.db_name = self.get_db_name()
38
        except TypeError as e:
39
            e.args = e.args + (
40
                "Check that DB path was created correctly and is not empty. "
41
                "Check that a correct dataset_format is passed to the ODQAReader config",
42
            )
43
            raise e
44
        self.doc_ids = self.get_doc_ids()
45
        self.doc2index = self.map_doc2idx()
46

47
    def __call__(self, par_ids_batch, entities_pages_batch, *args, **kwargs):
48
        all_contents, all_contents_ids, all_pages, all_from_linked_page, all_numbers = [], [], [], [], []
49
        for entities_pages, par_ids in zip(entities_pages_batch, par_ids_batch):
50
            page_contents, page_contents_ids, pages, from_linked_page, numbers = [], [], [], [], []
51
            for entity_pages in entities_pages:
52
                for entity_page in entity_pages[: self.top_n]:
53
                    cur_page_contents, cur_page_contents_ids, cur_pages = self.get_page_content(entity_page)
54
                    page_contents += cur_page_contents
55
                    page_contents_ids += cur_page_contents_ids
56
                    pages += cur_pages
57
                    from_linked_page += [True for _ in cur_pages]
58
                    numbers += list(range(len(cur_pages)))
59

60
            par_contents = []
61
            par_pages = []
62
            for par_id in par_ids:
63
                text, page = self.get_paragraph_content(par_id)
64
                par_contents.append(text)
65
                par_pages.append(page)
66
                from_linked_page.append(False)
67
                numbers.append(0)
68
            all_contents.append(page_contents + par_contents)
69
            all_contents_ids.append(page_contents_ids + par_ids)
70
            all_pages.append(pages + par_pages)
71
            all_from_linked_page.append(from_linked_page)
72
            all_numbers.append(numbers)
73

74
        return all_contents, all_contents_ids, all_pages, all_from_linked_page, all_numbers
75

76
    def get_paragraph_content(self, par_id):
77
        cursor = self.connect.cursor()
78
        cursor.execute("SELECT text, doc FROM {} WHERE title = ?".format(self.db_name), (par_id,))
79
        result = cursor.fetchone()
80
        cursor.close()
81
        return result
82

83
    def get_page_content(self, page):
84
        page = page.replace("_", " ")
85
        cursor = self.connect.cursor()
86
        cursor.execute("SELECT text, title FROM {} WHERE doc = ?".format(self.db_name), (page,))
87
        result = cursor.fetchall()
88
        paragraphs = [elem[0] for elem in result]
89
        titles = [elem[1] for elem in result]
90
        pages = [page for _ in result]
91
        cursor.close()
92
        return paragraphs, titles, pages
93

94
    def get_doc_ids(self):
95
        cursor = self.connect.cursor()
96
        cursor.execute("SELECT title FROM {}".format(self.db_name))
97
        ids = [ids[0] for ids in cursor.fetchall()]
98
        cursor.close()
99
        return ids
100

101
    def get_db_name(self):
102
        cursor = self.connect.cursor()
103
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
104
        assert cursor.arraysize == 1
105
        name = cursor.fetchone()[0]
106
        cursor.close()
107
        return name
108

109
    def map_doc2idx(self):
110
        doc2idx = {doc_id: i for i, doc_id in enumerate(self.doc_ids)}
111
        logger.info("SQLite iterator: The size of the database is {} documents".format(len(doc2idx)))
112
        return doc2idx
113

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.