dream

Форк
0
72 строки · 2.5 Кб
1
# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import os
16
import time
17
import logging
18

19
import sentry_sdk
20

21
from deeppavlov.core.common.registry import register
22
from deeppavlov.core.models.component import Component
23
from deeppavlov.dataset_iterators.sqlite_iterator import SQLiteDataIterator
24

25
# Error reporting: the Sentry DSN is taken from the SENTRY_DSN environment
# variable (os.getenv yields None when it is not set).
sentry_sdk.init(os.getenv("SENTRY_DSN"))

# Process-wide logging configuration: root handler at DEBUG level with a
# timestamp / logger-name / level prefix on every record.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
28

29

30
@register("wiki_sqlite_vocab")
class WikiSQLiteVocab(SQLiteDataIterator, Component):
    """Get content from SQLite database by document ids.

    Args:
        load_path: a path to local DB file
        join_docs: whether to join extracted docs with ' ' or not
        shuffle: whether to shuffle data or not

    Attributes:
        join_docs: whether to join extracted docs with ' ' or not

    """

    def __init__(self, load_path, join_docs=True, shuffle=False, **kwargs):
        # Extra keyword arguments are accepted (so config-driven construction
        # does not fail on unknown keys) but are not forwarded to the parent.
        SQLiteDataIterator.__init__(self, load_path=load_path, shuffle=shuffle)
        self.join_docs = join_docs

    def __call__(self, doc_ids_batch=None, *args, **kwargs):
        """Get the contents of documents, joined by a space or kept as lists.

        Args:
            doc_ids_batch: a batch whose items are lists of id groups; each
                id group is an iterable of document ids, each of which is
                looked up with ``get_doc_content``. ``None`` is treated as an
                empty batch.

        Returns:
            A list with one entry per batch item, each entry being a list with
            one element per id group. When ``join_docs`` is true that element
            is the group's contents joined with a single space; otherwise it
            is the list of contents itself.
        """
        # The parameter defaults to None; iterating None would raise
        # TypeError, so an absent batch yields an empty result instead.
        if doc_ids_batch is None:
            return []

        start = time.perf_counter()  # monotonic timer, suited to durations
        # Lazy %-style args: the (potentially large) batch and document
        # contents are only rendered if the record is actually emitted.
        logger.info("doc_ids_batch %s", doc_ids_batch)
        contents_batch = []
        for ids_list in doc_ids_batch:
            contents_list = []
            for ids in ids_list:
                contents = [self.get_doc_content(doc_id) for doc_id in ids]
                logger.debug("contents %s", contents)
                if self.join_docs:
                    # NOTE(review): assumes get_doc_content returns str for
                    # every id; a non-str value (e.g. for a missing document)
                    # would make the join raise TypeError — confirm.
                    contents = " ".join(contents)
                contents_list.append(contents)
            contents_batch.append(contents_list)
        logger.debug("sqlite vocab time %s", time.perf_counter() - start)

        return contents_batch
73

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.