stable-diffusion-webui
71 строка · 2.2 Кб
1from __future__ import annotations
2
3import logging
4import os
5
6import torch
7
8from modules import (
9devices,
10errors,
11face_restoration,
12face_restoration_utils,
13modelloader,
14shared,
15)
16
logger = logging.getLogger(__name__)

# GFPGAN v1.4 checkpoint, published under the upstream v1.3.0 release tag.
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
# Local file name for the downloaded checkpoint: the URL's last path segment.
model_download_name = model_url.rsplit("/", 1)[-1]

# Shared restorer instance; stays None until setup_model() succeeds.
gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
21
22
class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
    """Face restorer backed by a GFPGAN checkpoint loaded through spandrel.

    This subclass supplies the display name, the target device, the
    checkpoint loader (``load_net``) and the per-face inference callback.
    The rest of the restoration flow comes from
    ``face_restoration_utils.CommonFaceRestoration`` (not shown here).
    """

    def name(self):
        """Return the display name of this restorer."""
        return "GFPGAN"

    def get_device(self):
        """Return the device GFPGAN should run on (``devices.device_gfpgan``)."""
        return devices.device_gfpgan

    def load_net(self) -> torch.nn.Module:
        """Load the GFPGAN network from the first matching checkpoint.

        Candidate ``.pth`` files come from ``modelloader.load_models`` for
        ``self.model_path``. ``model_url`` and ``model_download_name`` are
        passed along; presumably they are used to download the checkpoint
        when none is present locally (confirm in ``modelloader``). The first
        candidate whose file name contains ``"GFPGAN"`` is loaded as a
        spandrel model of architecture ``"GFPGAN"`` on ``self.get_device()``.

        Returns:
            The underlying torch module (the ``.model`` of the spandrel
            descriptor).

        Raises:
            ValueError: if no candidate file name contains ``"GFPGAN"``.
        """
        for model_path in modelloader.load_models(
            model_path=self.model_path,
            model_url=model_url,
            command_path=self.model_path,
            download_name=model_download_name,
            ext_filter=['.pth'],
        ):
            if 'GFPGAN' in os.path.basename(model_path):
                model = modelloader.load_spandrel_model(
                    model_path,
                    device=self.get_device(),
                    expected_architecture='GFPGAN',
                ).model
                model.different_w = True  # see https://github.com/chaiNNer-org/spandrel/pull/81
                return model
        raise ValueError("No GFPGAN model found")

    def restore(self, np_image):
        """Restore the faces in *np_image* and return the processed image.

        Delegates to the inherited ``restore_with_helper`` (not shown here),
        passing ``restore_face`` as the per-face inference callback.
        """

        def restore_face(cropped_face_t):
            # self.net is never assigned in this class; presumably the base
            # class populates it (e.g. from load_net()) before calling back.
            # An explicit check, unlike `assert`, still runs under `python -O`.
            if self.net is None:
                raise RuntimeError("GFPGAN model is not loaded")
            # The network returns a sequence; element 0 is used as the
            # restored face tensor. NOTE(review): return_rgb=False presumably
            # skips auxiliary RGB outputs; confirm against spandrel's GFPGAN.
            return self.net(cropped_face_t, return_rgb=False)[0]

        return self.restore_with_helper(np_image, restore_face)
54
55
def gfpgan_fix_faces(np_image):
    """Run the module-level GFPGAN restorer over *np_image*.

    If ``setup_model()`` has not produced a restorer, log a warning and
    hand back *np_image* unchanged.
    """
    restorer = gfpgan_face_restorer
    if not restorer:
        logger.warning("GFPGAN face restorer not set up")
        return np_image
    return restorer.restore(np_image)
61
62
def setup_model(dirname: str) -> None:
    """Create the GFPGAN restorer for *dirname* and register it.

    Patches facexlib for *dirname*, stores the new restorer in the
    module-level ``gfpgan_face_restorer`` and appends it to
    ``shared.face_restorers``. Any exception along the way is reported via
    ``errors.report`` instead of being raised.
    """
    global gfpgan_face_restorer

    try:
        face_restoration_utils.patch_facexlib(dirname)
        restorer = FaceRestorerGFPGAN(model_path=dirname)
        gfpgan_face_restorer = restorer
        shared.face_restorers.append(restorer)
    except Exception:
        errors.report("Error setting up GFPGAN", exc_info=True)
72