llama-index

Форк
0
49 строк · 1.4 Кб
1
from enum import Enum
2
from typing import TYPE_CHECKING, Union, overload
3

4
import numpy as np
5

6
if TYPE_CHECKING:
7
    import torch
8

9

10
class Pooling(str, Enum):
11
    """Enum of possible pooling choices with pooling behaviors."""
12

13
    CLS = "cls"
14
    MEAN = "mean"
15

16
    def __call__(self, array: np.ndarray) -> np.ndarray:
17
        if self == self.CLS:
18
            return self.cls_pooling(array)
19
        return self.mean_pooling(array)
20

21
    @classmethod
22
    @overload
23
    def cls_pooling(cls, array: np.ndarray) -> np.ndarray:
24
        ...
25

26
    @classmethod
27
    @overload
28
    # TODO: Remove this `type: ignore` after the false positive problem
29
    #  is addressed in mypy: https://github.com/python/mypy/issues/15683 .
30
    def cls_pooling(cls, array: "torch.Tensor") -> "torch.Tensor":  # type: ignore
31
        ...
32

33
    @classmethod
34
    def cls_pooling(
35
        cls, array: "Union[np.ndarray, torch.Tensor]"
36
    ) -> "Union[np.ndarray, torch.Tensor]":
37
        if len(array.shape) == 3:
38
            return array[:, 0]
39
        if len(array.shape) == 2:
40
            return array[0]
41
        raise NotImplementedError(f"Unhandled shape {array.shape}.")
42

43
    @classmethod
44
    def mean_pooling(cls, array: np.ndarray) -> np.ndarray:
45
        if len(array.shape) == 3:
46
            return array.mean(axis=1)
47
        if len(array.shape) == 2:
48
            return array.mean(axis=0)
49
        raise NotImplementedError(f"Unhandled shape {array.shape}.")
50

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.