pytorch

Форк
0
/
convert_parameters.py 
83 строки · 3.1 Кб
1
import torch
2
from typing import Iterable, Optional
3

4

5
def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
6
    r"""Flatten an iterable of parameters into a single vector.
7

8
    Args:
9
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
10
            parameters of a model.
11

12
    Returns:
13
        The parameters represented by a single vector
14
    """
15
    # Flag for the device where the parameter is located
16
    param_device = None
17

18
    vec = []
19
    for param in parameters:
20
        # Ensure the parameters are located in the same device
21
        param_device = _check_param_device(param, param_device)
22

23
        vec.append(param.view(-1))
24
    return torch.cat(vec)
25

26

27
def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None:
28
    r"""Copy slices of a vector into an iterable of parameters.
29

30
    Args:
31
        vec (Tensor): a single vector representing the parameters of a model.
32
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
33
            parameters of a model.
34
    """
35
    # Ensure vec of type Tensor
36
    if not isinstance(vec, torch.Tensor):
37
        raise TypeError(f'expected torch.Tensor, but got: {torch.typename(vec)}')
38
    # Flag for the device where the parameter is located
39
    param_device = None
40

41
    # Pointer for slicing the vector for each parameter
42
    pointer = 0
43
    for param in parameters:
44
        # Ensure the parameters are located in the same device
45
        param_device = _check_param_device(param, param_device)
46

47
        # The length of the parameter
48
        num_param = param.numel()
49
        # Slice the vector, reshape it, and replace the old data of the parameter
50
        param.data = vec[pointer:pointer + num_param].view_as(param).data
51

52
        # Increment the pointer
53
        pointer += num_param
54

55

56
def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> int:
57
    r"""Check if the parameters are located on the same device.
58

59
    Currently, the conversion between model parameters and single vector form is not supported
60
    for multiple allocations, e.g. parameters in different GPUs/PrivateUse1s, or mixture of CPU/GPU/PrivateUse1.
61

62
    Args:
63
        param ([Tensor]): a Tensor of a parameter of a model
64
        old_param_device (int): the device where the first parameter of a
65
                                model is allocated.
66

67
    Returns:
68
        old_param_device (int): report device for the first time
69
    """
70
    # Meet the first parameter
71
    support_device_types = ["cuda", torch._C._get_privateuse1_backend_name()]
72
    if old_param_device is None:
73
        old_param_device = param.get_device() if param.device.type in support_device_types else -1
74
    else:
75
        warn = False
76
        if param.device.type in support_device_types:  # Check if in same GPU/PrivateUse1
77
            warn = (param.get_device() != old_param_device)
78
        else:  # Check if in CPU
79
            warn = (old_param_device != -1)
80
        if warn:
81
            raise TypeError('Found two parameters on different devices, '
82
                            'this is currently not supported.')
83
    return old_param_device
84

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.