BasicSR

Форк
0
/
inference_ridnet.py 
51 строка · 1.9 Кб
1
import argparse
2
import cv2
3
import glob
4
import numpy as np
5
import os
6
import torch
7
from tqdm import tqdm
8

9
from basicsr.archs.ridnet_arch import RIDNet
10
from basicsr.utils.img_util import img2tensor, tensor2img
11

12
if __name__ == '__main__':
13
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
    parser = argparse.ArgumentParser()
15
    parser.add_argument('--test_path', type=str, default='datasets/denoise/RNI15')
16
    parser.add_argument('--noise_g', type=int, default=25)
17
    parser.add_argument(
18
        '--model_path',
19
        type=str,
20
        default=  # noqa: E251
21
        'experiments/pretrained_models/RIDNet/RIDNet.pth')
22
    args = parser.parse_args()
23
    if args.test_path.endswith('/'):  # solve when path ends with /
24
        args.test_path = args.test_path[:-1]
25
    test_root = os.path.join(args.test_path, f'X{args.noise_g}')
26
    result_root = f'results/RIDNet/{os.path.basename(args.test_path)}'
27
    os.makedirs(result_root, exist_ok=True)
28

29
    # set up the RIDNet
30
    net = RIDNet(3, 64, 3).to(device)
31
    checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
32
    net.load_state_dict(checkpoint)
33
    net.eval()
34

35
    # scan all the jpg and png images
36
    img_list = sorted(glob.glob(os.path.join(test_root, '*.[jp][pn]g')))
37
    pbar = tqdm(total=len(img_list), desc='')
38
    for idx, img_path in enumerate(img_list):
39
        img_name = os.path.basename(img_path).split('.')[0]
40
        pbar.update(1)
41
        pbar.set_description(f'{idx}: {img_name}')
42
        # read image
43
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
44
        img = img2tensor(img, bgr2rgb=True, float32=True).unsqueeze(0).to(device)
45
        # inference
46
        with torch.no_grad():
47
            output = net(img)
48
        # save image
49
        output = tensor2img(output, rgb2bgr=True, out_type=np.uint8, min_max=(0, 255))
50
        save_img_path = os.path.join(result_root, f'{img_name}_x{args.noise_g}_RIDNet.png')
51
        cv2.imwrite(save_img_path, output)
52

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.