stable-diffusion-webui
43 строки · 1.6 Кб
1import os2import sys3
4from modules import modelloader, devices5from modules.shared import opts6from modules.upscaler import Upscaler, UpscalerData7from modules.upscaler_utils import upscale_with_model8
9
class UpscalerHAT(Upscaler):
    """Upscaler for checkpoints loaded by spandrel with ``expected_architecture='HAT'``."""

    def __init__(self, dirname):
        """Collect HAT checkpoints and expose each one as an ``UpscalerData`` entry.

        Args:
            dirname: stored as ``self.user_path`` before the base initializer
                runs (presumably one of the locations ``find_models`` searches
                — verify against ``Upscaler``).

        ``name``, ``scalers`` and ``user_path`` are assigned before
        ``super().__init__()``; keep that order unless it has been confirmed
        that the base initializer does not read them.
        """
        self.name = "HAT"
        self.scalers = []
        self.user_path = dirname
        super().__init__()

        # TODO: scale might not be 4, but we can't know without loading the model
        assumed_scale = 4
        for model_file in self.find_models(ext_filter=[".pt", ".pth"]):
            display_name = modelloader.friendly_name(model_file)
            entry = UpscalerData(display_name, model_file, upscaler=self, scale=assumed_scale)
            self.scalers.append(entry)

    def do_upscale(self, img, selected_model):
        """Upscale ``img`` using the HAT checkpoint at path ``selected_model``.

        If loading the checkpoint raises, the error is printed to stderr and
        ``img`` is returned unchanged. Errors raised by the upscaling step
        itself are not caught and propagate to the caller.
        """
        try:
            model = self.load_model(selected_model)
        except Exception as e:
            # Best-effort: report the load failure and pass the input through.
            print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
            return img

        model.to(devices.device_esrgan)  # TODO: should probably be device_hat

        # Tiling currently reuses the ESRGAN tile options.
        tile_size = opts.ESRGAN_tile  # TODO: should probably be HAT_tile
        tile_overlap = opts.ESRGAN_tile_overlap  # TODO: should probably be HAT_tile_overlap
        return upscale_with_model(model, img, tile_size=tile_size, tile_overlap=tile_overlap)

    def load_model(self, path: str):
        """Load the checkpoint at ``path`` via ``modelloader.load_spandrel_model``.

        The model is requested on ``devices.device_esrgan`` and must match the
        ``'HAT'`` architecture.

        Raises:
            FileNotFoundError: if ``path`` is not an existing regular file.
        """
        if os.path.isfile(path):
            return modelloader.load_spandrel_model(
                path,
                device=devices.device_esrgan,  # TODO: should probably be device_hat
                expected_architecture='HAT',
            )
        raise FileNotFoundError(f"Model file {path} not found")