stable-diffusion-webui

Форк
0
62 строки · 2.1 Кб
1
from modules import modelloader, devices, errors
2
from modules.shared import opts
3
from modules.upscaler import Upscaler, UpscalerData
4
from modules.upscaler_utils import upscale_with_model
5

6

7
class UpscalerESRGAN(Upscaler):
8
    def __init__(self, dirname):
9
        self.name = "ESRGAN"
10
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
11
        self.model_name = "ESRGAN_4x"
12
        self.scalers = []
13
        self.user_path = dirname
14
        super().__init__()
15
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
16
        scalers = []
17
        if len(model_paths) == 0:
18
            scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
19
            scalers.append(scaler_data)
20
        for file in model_paths:
21
            if file.startswith("http"):
22
                name = self.model_name
23
            else:
24
                name = modelloader.friendly_name(file)
25

26
            scaler_data = UpscalerData(name, file, self, 4)
27
            self.scalers.append(scaler_data)
28

29
    def do_upscale(self, img, selected_model):
30
        try:
31
            model = self.load_model(selected_model)
32
        except Exception:
33
            errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
34
            return img
35
        model.to(devices.device_esrgan)
36
        return esrgan_upscale(model, img)
37

38
    def load_model(self, path: str):
39
        if path.startswith("http"):
40
            # TODO: this doesn't use `path` at all?
41
            filename = modelloader.load_file_from_url(
42
                url=self.model_url,
43
                model_dir=self.model_download_path,
44
                file_name=f"{self.model_name}.pth",
45
            )
46
        else:
47
            filename = path
48

49
        return modelloader.load_spandrel_model(
50
            filename,
51
            device=('cpu' if devices.device_esrgan.type == 'mps' else None),
52
            expected_architecture='ESRGAN',
53
        )
54

55

56
def esrgan_upscale(model, img):
57
    return upscale_with_model(
58
        model,
59
        img,
60
        tile_size=opts.ESRGAN_tile,
61
        tile_overlap=opts.ESRGAN_tile_overlap,
62
    )
63

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.