pytorch

Форк
0
34 строки · 1.1 Кб
1
## @package tools
2
# Module caffe2.python.helpers.tools
3

4

5

6

7

8

9
def image_input(
10
    model, blob_in, blob_out, order="NCHW", use_gpu_transform=False, **kwargs
11
):
12
    assert 'is_test' in kwargs, "Argument 'is_test' is required"
13
    if order == "NCHW":
14
        if (use_gpu_transform):
15
            kwargs['use_gpu_transform'] = 1 if use_gpu_transform else 0
16
            # GPU transform will handle NHWC -> NCHW
17
            outputs = model.net.ImageInput(blob_in, blob_out, **kwargs)
18
            pass
19
        else:
20
            outputs = model.net.ImageInput(
21
                blob_in, [blob_out[0] + '_nhwc'] + blob_out[1:], **kwargs
22
            )
23
            outputs_list = list(outputs)
24
            outputs_list[0] = model.net.NHWC2NCHW(outputs_list[0], blob_out[0])
25
            outputs = tuple(outputs_list)
26
    else:
27
        outputs = model.net.ImageInput(blob_in, blob_out, **kwargs)
28
    return outputs
29

30

31
def video_input(model, blob_in, blob_out, **kwargs):
32
    # size of outputs can vary depending on kwargs
33
    outputs = model.net.VideoInput(blob_in, blob_out, **kwargs)
34
    return outputs
35

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.