# dream
# 72 lines · 2.5 KB
# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import time

import sentry_sdk

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.dataset_iterators.sqlite_iterator import SQLiteDataIterator

# Initialise the Sentry SDK with the DSN taken from the SENTRY_DSN environment
# variable (os.getenv yields None when the variable is not set).
sentry_sdk.init(os.getenv("SENTRY_DSN"))

# NOTE(review): basicConfig configures the process-wide root logger at DEBUG
# level as a side effect of importing this module; consider moving it to the
# service entry point.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

@register("wiki_sqlite_vocab")
class WikiSQLiteVocab(SQLiteDataIterator, Component):
    """Get content from an SQLite database by document ids.

    Args:
        load_path: a path to the local DB file
        join_docs: whether to join the docs extracted for each id group with ' '
        shuffle: whether to shuffle data or not (passed to SQLiteDataIterator)
        **kwargs: accepted for config compatibility and ignored

    Attributes:
        join_docs: whether to join the docs extracted for each id group with ' '

    """

    def __init__(self, load_path, join_docs=True, shuffle=False, **kwargs):
        # Only load_path and shuffle reach the base iterator; extra kwargs are dropped.
        SQLiteDataIterator.__init__(self, load_path=load_path, shuffle=shuffle)
        self.join_docs = join_docs

    def __call__(self, doc_ids_batch=None, *args, **kwargs):
        """Fetch document contents for a batch of nested id groups.

        Args:
            doc_ids_batch: a batch whose items are lists of id groups, where each
                id group is an iterable of document ids. ``None`` is treated as
                an empty batch.

        Returns:
            A list with one entry per batch item. Each entry is a list with one
            element per id group: a single ``' '``-joined string when
            ``join_docs`` is true, otherwise the list of per-id contents.
        """
        if doc_ids_batch is None:
            # The default used to crash with TypeError when iterated.
            doc_ids_batch = []
        start = time.perf_counter()  # monotonic clock, suited to durations
        # Lazy %-args: the (potentially large) batch is only formatted if emitted.
        logger.info("doc_ids_batch %s", doc_ids_batch)
        contents_batch = [
            [self._get_group_contents(ids) for ids in ids_list]
            for ids_list in doc_ids_batch
        ]
        logger.debug("sqlite vocab time %s", time.perf_counter() - start)
        return contents_batch

    def _get_group_contents(self, ids):
        """Return the contents of one id group, ' '-joined if ``join_docs`` is set."""
        contents = [self.get_doc_content(doc_id) for doc_id in ids]
        logger.debug("contents %s", contents)
        if self.join_docs:
            # NOTE(review): str.join raises TypeError if get_doc_content returns a
            # non-string (e.g. for an unknown id) — confirm against SQLiteDataIterator.
            return " ".join(contents)
        return contents