llama-index
49 строк · 1.4 Кб
1from enum import Enum2from typing import TYPE_CHECKING, Union, overload3
4import numpy as np5
6if TYPE_CHECKING:7import torch8
9
class Pooling(str, Enum):
    """Enum of possible pooling choices with pooling behaviors.

    Members are callable: ``Pooling.CLS(x)`` keeps only index 0 of the
    pooled axis, while ``Pooling.MEAN(x)`` averages over it. The pooled axis
    is axis 1 for a rank-3 input and axis 0 for a rank-2 input; any other
    rank raises ``NotImplementedError``. (A rank-3 input is presumably laid
    out as (batch, sequence, hidden) -- assumed by callers, not checked
    here.)
    """

    CLS = "cls"
    MEAN = "mean"

    def __call__(self, array: np.ndarray) -> np.ndarray:
        """Pool ``array`` using the strategy this member names."""
        # CLS is the only special case; every other member means "average".
        pool = self.cls_pooling if self == Pooling.CLS else self.mean_pooling
        return pool(array)

    @classmethod
    @overload
    def cls_pooling(cls, array: np.ndarray) -> np.ndarray:
        ...

    @classmethod
    @overload
    # TODO: Remove this `type: ignore` after the false positive problem
    # is addressed in mypy: https://github.com/python/mypy/issues/15683 .
    def cls_pooling(cls, array: "torch.Tensor") -> "torch.Tensor":  # type: ignore
        ...

    @classmethod
    def cls_pooling(
        cls, array: "Union[np.ndarray, torch.Tensor]"
    ) -> "Union[np.ndarray, torch.Tensor]":
        """Select position 0 of the pooled axis.

        Args:
            array: A rank-2 or rank-3 array; per the overloads above, either
                a NumPy array or a torch tensor.

        Returns:
            ``array[0]`` for rank 2, ``array[:, 0]`` for rank 3.

        Raises:
            NotImplementedError: If ``array`` has any other rank.
        """
        ndim = len(array.shape)
        if ndim == 2:
            return array[0]
        if ndim == 3:
            # Keep every entry along axis 0, take index 0 along axis 1.
            return array[:, 0]
        raise NotImplementedError(f"Unhandled shape {array.shape}.")

    @classmethod
    def mean_pooling(cls, array: np.ndarray) -> np.ndarray:
        """Average over the pooled axis.

        Args:
            array: A rank-2 or rank-3 array exposing ``.mean(axis=...)``.

        Returns:
            ``array.mean(axis=0)`` for rank 2, ``array.mean(axis=1)`` for
            rank 3.

        Raises:
            NotImplementedError: If ``array`` has any other rank.
        """
        # Map the input rank to the axis that gets reduced.
        reduce_axis = {3: 1, 2: 0}.get(len(array.shape))
        if reduce_axis is None:
            raise NotImplementedError(f"Unhandled shape {array.shape}.")
        return array.mean(axis=reduce_axis)