llama-index

Форк
0
170 строк · 5.3 Кб
1
"""Discord reader.

Note: this file is named discord_reader.py to avoid conflicts with the
discord.py module.

"""
7

8
import asyncio
9
import logging
10
import os
11
from typing import List, Optional
12

13
from llama_index.legacy.readers.base import BasePydanticReader
14
from llama_index.legacy.schema import Document
15

16
# Module-level logger named after this module; used for connection, error
# and demo-output messages in the code below.
logger = logging.getLogger(__name__)
17

18

19
async def read_channel(
    discord_token: str,
    channel_id: int,
    limit: Optional[int],
    oldest_first: bool,
) -> List[Document]:
    """Collect the messages of one Discord text channel as Documents.

    A short-lived ``discord.Client`` subclass is started with
    ``discord_token``. Its ``on_ready`` handler gathers the channel history
    (plus the history of any thread whose id equals a message id) and then
    closes the client. Synchronous callers drive this coroutine with
    ``asyncio.get_event_loop().run_until_complete``.

    Args:
        discord_token: Token handed to ``client.start``.
        channel_id: Id of the channel to read. It must resolve to a
            ``discord.TextChannel``.
        limit: Forwarded as ``limit`` to every ``history`` call, i.e. applied
            to the channel and to each matched thread separately.
        oldest_first: Forwarded as ``oldest_first`` to every ``history`` call.

    Returns:
        One Document per collected message. Errors raised while collecting
        (including the non-text-channel ``ValueError``) are logged and not
        re-raised, so the result then contains only what was gathered before
        the failure.

    """
    import discord

    collected: List[discord.Message] = []

    class CustomClient(discord.Client):
        async def on_ready(self) -> None:
            try:
                logger.info(f"{self.user} has connected to Discord!")
                channel = self.get_channel(channel_id)
                # Only text channels are handled for now.
                if not isinstance(channel, discord.TextChannel):
                    raise ValueError(
                        f"Channel {channel_id} is not a text channel. "
                        "Only text channels are supported for now."
                    )
                # Index the channel's threads by id so a message can be
                # matched to the thread carrying the same id.
                threads_by_id = {item.id: item for item in channel.threads}
                async for message in channel.history(
                    limit=limit, oldest_first=oldest_first
                ):
                    collected.append(message)
                    matching_thread = threads_by_id.get(message.id)
                    if matching_thread is None:
                        continue
                    # Thread replies are placed right after the message
                    # whose id matches the thread id.
                    async for reply in matching_thread.history(
                        limit=limit, oldest_first=oldest_first
                    ):
                        collected.append(reply)
            except Exception as e:
                logger.error("Encountered error: " + str(e))
            finally:
                # Always shut the client down, whether or not reading worked.
                await self.close()

    intents = discord.Intents.default()
    intents.message_content = True
    client = CustomClient(intents=intents)
    # NOTE(review): presumably returns once `on_ready` has closed the client;
    # confirm against the discord.py `Client.start` / `Client.close` docs.
    await client.start(discord_token)

    # Wrap every message in a Document holding its text together with a few
    # useful metadata properties.
    documents: List[Document] = []
    for message in collected:
        documents.append(
            Document(
                text=message.content,
                id_=message.id,
                metadata={
                    "message_id": message.id,
                    "username": message.author.name,
                    "created_at": message.created_at,
                    "edited_at": message.edited_at,
                },
            )
        )
    return documents
86

87

88
class DiscordReader(BasePydanticReader):
    """Discord reader.

    Reads conversations from channels.

    Args:
        discord_token (Optional[str]): Discord token. If not provided, it is
            read from the environment variable `DISCORD_TOKEN`.

    Raises:
        ImportError: If the `discord.py` package cannot be imported.
        ValueError: If no token is passed and `DISCORD_TOKEN` is not set.

    """

    is_remote: bool = True
    discord_token: str

    def __init__(self, discord_token: Optional[str] = None) -> None:
        """Initialize with parameters.

        Checks that `discord.py` is importable and resolves the token from
        the argument or, failing that, from the environment.
        """
        try:
            import discord  # noqa
        except ImportError as err:
            raise ImportError(
                "`discord.py` package not found, please run `pip install discord.py`"
            ) from err
        if discord_token is None:
            # Use `.get` rather than indexing: `os.environ[...]` raises
            # KeyError for a missing variable, which would skip the more
            # helpful ValueError below.
            discord_token = os.environ.get("DISCORD_TOKEN")
            if discord_token is None:
                raise ValueError(
                    "Must specify `discord_token` or set environment "
                    "variable `DISCORD_TOKEN`."
                )

        super().__init__(discord_token=discord_token)

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this reader class."""
        return "DiscordReader"

    def _read_channel(
        self, channel_id: int, limit: Optional[int] = None, oldest_first: bool = True
    ) -> List[Document]:
        """Read one channel synchronously.

        Runs the async `read_channel` coroutine to completion on the event
        loop returned by `asyncio.get_event_loop()`.
        """
        return asyncio.get_event_loop().run_until_complete(
            read_channel(
                self.discord_token, channel_id, limit=limit, oldest_first=oldest_first
            )
        )

    def load_data(
        self,
        channel_ids: List[int],
        limit: Optional[int] = None,
        oldest_first: bool = True,
    ) -> List[Document]:
        """Load messages from the given Discord channels.

        Args:
            channel_ids (List[int]): List of channel ids to read.
            limit (Optional[int]): Maximum number of messages to read; it is
                forwarded unchanged for every channel.
            oldest_first (bool): Whether to read oldest messages first.
                Defaults to `True`.

        Returns:
            List[Document]: Documents of all channels, concatenated in the
                order of `channel_ids`.

        Raises:
            ValueError: If an entry of `channel_ids` is not an integer. The
                check happens when that entry is reached, so earlier channels
                have already been read by then.

        """
        results: List[Document] = []
        for channel_id in channel_ids:
            if not isinstance(channel_id, int):
                raise ValueError(
                    f"Channel id {channel_id} must be an integer, "
                    f"not {type(channel_id)}."
                )
            channel_documents = self._read_channel(
                channel_id, limit=limit, oldest_first=oldest_first
            )
            results += channel_documents
        return results
164

165

166
if __name__ == "__main__":
    # Manual demo: build a reader (the token comes from the environment, as
    # none is passed) and log up to 10 messages of a hard-coded channel.
    discord_reader = DiscordReader()
    logger.info("initialized reader")
    documents = discord_reader.load_data(
        limit=10,
        channel_ids=[1057178784895348746],
    )
    logger.info(documents)
171

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.