llama-index
184 строки · 5.9 Кб
1"""Notion reader."""
2
3import logging4import os5from typing import Any, Dict, List, Optional6
7import requests # type: ignore8
9from llama_index.legacy.readers.base import BasePydanticReader10from llama_index.legacy.schema import Document11
# Name of the environment variable consulted when no token is passed explicitly.
INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"

# Notion REST endpoints; the templated ones are filled in with `str.format`.
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
19
# TODO: Notion DB reader coming soon!
class NotionPageReader(BasePydanticReader):
    """Notion Page reader.

    Reads a set of Notion pages through the Notion REST API and converts each
    page into a ``Document`` holding the page's text.

    Args:
        integration_token (Optional[str]): Notion integration token. When
            omitted, the ``NOTION_INTEGRATION_TOKEN`` environment variable is
            used instead.
        headers (Optional[Dict]): HTTP headers sent with every request. When
            omitted (or empty), default headers carrying the bearer token,
            a JSON content type and the ``Notion-Version`` are built.

    Raises:
        ValueError: If no token is passed and the environment variable is
            not set.

    """

    is_remote: bool = True
    integration_token: str
    headers: Dict[str, str]

    def __init__(
        self, integration_token: Optional[str] = None, headers: Optional[Dict] = None
    ) -> None:
        """Initialize with parameters."""
        if integration_token is None:
            integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
            if integration_token is None:
                raise ValueError(
                    "Must specify `integration_token` or set environment "
                    "variable `NOTION_INTEGRATION_TOKEN`."
                )

        headers = headers or {
            "Authorization": "Bearer " + integration_token,
            "Content-Type": "application/json",
            "Notion-Version": "2022-06-28",
        }
        super().__init__(integration_token=integration_token, headers=headers)

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this reader class."""
        return "NotionPageReader"

    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
        """Read the text of a block's children, recursing into nested blocks.

        Args:
            block_id (str): Id of the block (or page) whose children are read.
            num_tabs (int): Nesting depth; each text fragment found at this
                level is prefixed with that many tab characters.

        Returns:
            str: Newline-joined text of all child blocks.

        Raises:
            requests.HTTPError: If the Notion API answers with an error status.

        """
        block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
        prefix = "\t" * num_tabs
        result_lines_arr: List[str] = []
        next_cursor: Optional[str] = None

        while True:
            # Pagination continues the listing of the SAME block: the cursor
            # goes into the `start_cursor` query parameter. Substituting it
            # for the block id in the URL would list a different block.
            params: Dict[str, Any] = {}
            if next_cursor is not None:
                params["start_cursor"] = next_cursor

            res = requests.request(
                "GET", block_url, headers=self.headers, params=params
            )
            # Fail loudly on API errors instead of a KeyError on "results".
            res.raise_for_status()
            data = res.json()

            for result in data["results"]:
                result_type = result["type"]
                result_obj = result[result_type]

                cur_result_text_arr = []
                if "rich_text" in result_obj:
                    for rich_text in result_obj["rich_text"]:
                        # skip if doesn't have text object
                        if "text" in rich_text:
                            text = rich_text["text"]["content"]
                            cur_result_text_arr.append(prefix + text)

                if result["has_children"]:
                    children_text = self._read_block(
                        result["id"], num_tabs=num_tabs + 1
                    )
                    cur_result_text_arr.append(children_text)

                result_lines_arr.append("\n".join(cur_result_text_arr))

            next_cursor = data.get("next_cursor")
            if next_cursor is None:
                break

        return "\n".join(result_lines_arr)

    def read_page(self, page_id: str) -> str:
        """Read a page and return its text."""
        return self._read_block(page_id)

    def query_database(
        self, database_id: str, query_dict: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """Get all the pages from a Notion database.

        Args:
            database_id (str): Id of the database to query.
            query_dict (Optional[Dict[str, Any]]): Extra body fields for the
                query request. The dict is copied, never mutated.

        Returns:
            List[str]: Ids of every result, across all result pages.

        Raises:
            requests.HTTPError: If the Notion API answers with an error status.

        """
        # Work on a copy so the pagination cursor never leaks into the
        # caller's dict (or into a shared default).
        payload: Dict[str, Any] = dict(query_dict) if query_dict else {}
        url = DATABASE_URL_TMPL.format(database_id=database_id)

        page_ids: List[str] = []
        while True:
            res = requests.post(url, headers=self.headers, json=payload)
            res.raise_for_status()
            data = res.json()
            page_ids.extend(result["id"] for result in data["results"])

            # Keep requesting until the API reports no further cursor;
            # otherwise only the first result page would be returned.
            next_cursor = data.get("next_cursor")
            if next_cursor is None:
                break
            payload["start_cursor"] = next_cursor

        return page_ids

    def search(self, query: str) -> List[str]:
        """Search Notion page given a text query.

        Args:
            query (str): Text to search for.

        Returns:
            List[str]: Ids of every search result, across all result pages.

        Raises:
            requests.HTTPError: If the Notion API answers with an error status.

        """
        page_ids: List[str] = []
        next_cursor: Optional[str] = None
        while True:
            query_dict: Dict[str, Any] = {"query": query}
            if next_cursor is not None:
                query_dict["start_cursor"] = next_cursor

            res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict)
            res.raise_for_status()
            data = res.json()
            page_ids.extend(result["id"] for result in data["results"])

            next_cursor = data.get("next_cursor")
            if next_cursor is None:
                break

        return page_ids

    def load_data(
        self, page_ids: Optional[List[str]] = None, database_id: Optional[str] = None
    ) -> List[Document]:
        """Load Notion pages as documents.

        Args:
            page_ids (Optional[List[str]]): List of page ids to load.
            database_id (Optional[str]): Id of a database whose pages are
                loaded. When given, it takes precedence over ``page_ids``.

        Returns:
            List[Document]: One document per page, with the page id stored
                as the document id and under the ``page_id`` metadata key.

        Raises:
            ValueError: If neither ``page_ids`` nor ``database_id`` is given.

        """
        if not page_ids and not database_id:
            raise ValueError("Must specify either `page_ids` or `database_id`.")

        if database_id is not None:
            # get all the pages in the database
            page_ids = self.query_database(database_id)

        return [
            Document(
                text=self.read_page(page_id),
                id_=page_id,
                metadata={"page_id": page_id},
            )
            for page_id in (page_ids or [])
        ]
181
if __name__ == "__main__":
    # Manual smoke test: build a reader from the environment token, run a
    # sample search, and log the ids that come back.
    found_page_ids = NotionPageReader().search("What I")
    logger.info(found_page_ids)