llama-index

Форк
0
184 строки · 5.9 Кб
1
"""Notion reader."""
2

3
import logging
4
import os
5
from typing import Any, Dict, List, Optional
6

7
import requests  # type: ignore
8

9
from llama_index.legacy.readers.base import BasePydanticReader
10
from llama_index.legacy.schema import Document
11

12
# Environment variable read when no integration token is passed explicitly.
INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"

# Notion REST endpoints. The ``*_TMPL`` values are ``str.format`` templates
# that take the block / database id as a keyword argument.
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
18

19

20
# TODO: Notion DB reader coming soon!
class NotionPageReader(BasePydanticReader):
    """Notion Page reader.

    Reads a set of Notion pages, given either explicit page ids or a database
    whose pages should all be loaded.

    Args:
        integration_token (Optional[str]): Notion integration token. When
            omitted, the ``NOTION_INTEGRATION_TOKEN`` environment variable is
            used instead.
        headers (Optional[Dict]): HTTP headers sent with every request. When
            omitted (or empty), bearer-token auth, a JSON content type and
            ``Notion-Version: 2022-06-28`` are used.

    """

    is_remote: bool = True
    integration_token: str
    headers: Dict[str, str]

    def __init__(
        self, integration_token: Optional[str] = None, headers: Optional[Dict] = None
    ) -> None:
        """Initialize with parameters.

        Raises:
            ValueError: If no token is passed and the environment variable
                named by ``INTEGRATION_TOKEN_NAME`` is not set.
        """
        if integration_token is None:
            integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
            if integration_token is None:
                raise ValueError(
                    "Must specify `integration_token` or set environment "
                    "variable `NOTION_INTEGRATION_TOKEN`."
                )

        headers = headers or {
            "Authorization": "Bearer " + integration_token,
            "Content-Type": "application/json",
            "Notion-Version": "2022-06-28",
        }
        super().__init__(integration_token=integration_token, headers=headers)

    @classmethod
    def class_name(cls) -> str:
        return "NotionPageReader"

    def _request_json(self, method: str, url: str, **kwargs: Any) -> Dict[str, Any]:
        """Send an authenticated Notion API request and return the decoded JSON.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            url: Full endpoint URL.
            **kwargs: Forwarded to ``requests.request`` (``params``, ``json``...).

        Raises:
            requests.HTTPError: If the API answers with a 4xx/5xx status.
        """
        res = requests.request(method, url, headers=self.headers, **kwargs)
        # Fail fast on HTTP errors instead of indexing into an error payload
        # (which would surface later as a confusing KeyError).
        res.raise_for_status()
        return res.json()

    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
        """Return the text of the children of ``block_id``, recursing into nested blocks.

        Only ``rich_text`` segments that carry a ``text`` object are kept; each
        such segment becomes its own line prefixed with ``num_tabs`` tabs. A
        child's own children are rendered directly after it, one tab deeper.

        Args:
            block_id: Id of the page or block whose children are read.
            num_tabs: Indentation depth applied to the text of these children.

        Returns:
            The extracted text, newline-separated.
        """
        block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
        prefix = "\t" * num_tabs
        result_lines_arr: List[str] = []
        params: Dict[str, Any] = {}
        while True:
            data = self._request_json("GET", block_url, params=params)

            for result in data["results"]:
                result_obj = result[result["type"]]

                cur_result_text_arr: List[str] = []
                if "rich_text" in result_obj:
                    for rich_text in result_obj["rich_text"]:
                        # skip segments that have no text object
                        if "text" in rich_text:
                            cur_result_text_arr.append(
                                prefix + rich_text["text"]["content"]
                            )

                if result["has_children"]:
                    cur_result_text_arr.append(
                        self._read_block(result["id"], num_tabs=num_tabs + 1)
                    )

                result_lines_arr.append("\n".join(cur_result_text_arr))

            # Children are paginated: the next page is fetched from the SAME
            # block by passing the cursor as ``start_cursor``. The cursor is
            # not a block id and must not be put into the URL.
            next_cursor = data.get("next_cursor")
            if next_cursor is None:
                break
            params = {"start_cursor": next_cursor}

        return "\n".join(result_lines_arr)

    def read_page(self, page_id: str) -> str:
        """Read a page and return its text content."""
        return self._read_block(page_id)

    def query_database(
        self, database_id: str, query_dict: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """Get the ids of all the pages in a Notion database.

        Follows ``next_cursor`` pagination, so every matching page is returned,
        not only those in the first response.

        Args:
            database_id: Id of the database to query.
            query_dict: Optional request body (e.g. filters or sorts) sent with
                every request. It is copied and never modified.

        Returns:
            Page ids in the order the API returned them.
        """
        url = DATABASE_URL_TMPL.format(database_id=database_id)
        # Copy so neither the caller's dict nor a shared default gets mutated
        # when the pagination cursor is added.
        body: Dict[str, Any] = dict(query_dict or {})
        page_ids: List[str] = []
        while True:
            data = self._request_json("POST", url, json=body)
            page_ids.extend(result["id"] for result in data["results"])
            next_cursor = data.get("next_cursor")
            if next_cursor is None:
                return page_ids
            body["start_cursor"] = next_cursor

    def search(self, query: str) -> List[str]:
        """Search Notion given a text query and return the ids of all results.

        Follows ``next_cursor`` pagination until every result is collected.
        NOTE(review): no object-type filter is sent, so every result the search
        endpoint returns is included — confirm only pages are expected.
        """
        body: Dict[str, Any] = {"query": query}
        page_ids: List[str] = []
        while True:
            data = self._request_json("POST", SEARCH_URL, json=body)
            page_ids.extend(result["id"] for result in data["results"])
            next_cursor = data.get("next_cursor")
            if next_cursor is None:
                return page_ids
            body["start_cursor"] = next_cursor

    def load_data(
        self, page_ids: Optional[List[str]] = None, database_id: Optional[str] = None
    ) -> List[Document]:
        """Load Notion pages as documents.

        Args:
            page_ids (Optional[List[str]]): Ids of the pages to load. Ignored
                when ``database_id`` is given.
            database_id (Optional[str]): Id of a database; when set, every page
                in it is loaded instead of ``page_ids``.

        Returns:
            List[Document]: One document per page, with the page id used as
            both ``id_`` and ``metadata["page_id"]``.

        Raises:
            ValueError: If neither ``page_ids`` nor ``database_id`` is given.
        """
        if not page_ids and not database_id:
            raise ValueError("Must specify either `page_ids` or `database_id`.")
        if database_id is not None:
            # get all the pages in the database
            page_ids = self.query_database(database_id)

        return [
            Document(
                text=self.read_page(page_id),
                id_=page_id,
                metadata={"page_id": page_id},
            )
            for page_id in page_ids or []
        ]
180

181

182
if __name__ == "__main__":
    # INFO records are dropped unless logging is configured, so without this
    # the demo would print nothing.
    logging.basicConfig(level=logging.INFO)
    reader = NotionPageReader()
    logger.info("%s", reader.search("What I"))

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.