pytorch

Форк
0
/
rnn.pyi 
102 строки · 2.5 Кб
1
from typing import (
2
    Any,
3
    Iterable,
4
    NamedTuple,
5
    Optional,
6
    overload,
7
    Sequence,
8
    Tuple,
9
    TypeVar,
10
    Union,
11
)
12

13
from typing_extensions import Self
14

15
from torch import Tensor
16

17
from torch._prims_common import DeviceLikeType
18
from torch.types import _dtype
19

20
class PackedSequence_(NamedTuple):
21
    data: Tensor
22
    batch_sizes: Tensor
23
    sorted_indices: Optional[Tensor]
24
    unsorted_indices: Optional[Tensor]
25

26
def bind(optional: Any, fn: Any): ...
27

28
_T = TypeVar("_T")
29

30
class PackedSequence(PackedSequence_):
31
    def __new__(
32
        cls,
33
        data: Tensor,
34
        batch_sizes: Optional[Tensor] = ...,
35
        sorted_indices: Optional[Tensor] = ...,
36
        unsorted_indices: Optional[Tensor] = ...,
37
    ) -> Self: ...
38
    def pin_memory(self: _T) -> _T: ...
39
    def cuda(self: _T, *args: Any, **kwargs: Any) -> _T: ...
40
    def cpu(self: _T) -> _T: ...
41
    def double(self: _T) -> _T: ...
42
    def float(self: _T) -> _T: ...
43
    def half(self: _T) -> _T: ...
44
    def long(self: _T) -> _T: ...
45
    def int(self: _T) -> _T: ...
46
    def short(self: _T) -> _T: ...
47
    def char(self: _T) -> _T: ...
48
    def byte(self: _T) -> _T: ...
49
    @overload
50
    def to(
51
        self: _T,
52
        dtype: _dtype,
53
        non_blocking: bool = False,
54
        copy: bool = False,
55
    ) -> _T: ...
56
    @overload
57
    def to(
58
        self: _T,
59
        device: Optional[DeviceLikeType] = None,
60
        dtype: Optional[_dtype] = None,
61
        non_blocking: bool = False,
62
        copy: bool = False,
63
    ) -> _T: ...
64
    @overload
65
    def to(
66
        self: _T,
67
        other: Tensor,
68
        non_blocking: bool = False,
69
        copy: bool = False,
70
    ) -> _T: ...
71
    @property
72
    def is_cuda(self) -> bool: ...
73
    def is_pinned(self) -> bool: ...
74

75
def invert_permutation(permutation: Optional[Tensor]): ...
76
def pack_padded_sequence(
77
    input: Tensor,
78
    lengths: Tensor,
79
    batch_first: bool = ...,
80
    enforce_sorted: bool = ...,
81
) -> PackedSequence: ...
82
def pad_packed_sequence(
83
    sequence: PackedSequence,
84
    batch_first: bool = ...,
85
    padding_value: float = ...,
86
    total_length: Optional[int] = ...,
87
) -> Tuple[Tensor, ...]: ...
88
def pad_sequence(
89
    sequences: Union[Tensor, Iterable[Tensor]],
90
    batch_first: bool = False,
91
    padding_value: float = ...,
92
) -> Tensor: ...
93
def pack_sequence(
94
    sequences: Sequence[Tensor],
95
    enforce_sorted: bool = ...,
96
) -> PackedSequence: ...
97
def get_packed_sequence(
98
    data: Tensor,
99
    batch_sizes: Optional[Tensor],
100
    sorted_indices: Optional[Tensor],
101
    unsorted_indices: Optional[Tensor],
102
) -> PackedSequence: ...
103

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.