stable-diffusion-webui
241 lines · 9.1 KB
1from __future__ import annotations2
3import configparser4import os5import threading6import re7
8from modules import shared, errors, cache, scripts9from modules.gitpython_hack import Repo10from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F40111
12
# Ensure the user extensions directory exists as soon as this module is imported;
# an already-existing directory is not an error.
os.makedirs(name=extensions_dir, exist_ok=True)
15
def active():
    """Return the loaded extensions that are currently in effect.

    - Nothing is active when all extensions are disabled, either by the command-line
      flag or by the ``"all"`` value of the ``disable_all_extensions`` option.
    - Only enabled built-in extensions are active when extra extensions are disabled,
      either by the command-line flag or by the ``"extra"`` option value.
    - Otherwise every enabled extension is active.
    """
    if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
        return []

    builtin_only = shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra"

    def wanted(ext):
        if not ext.enabled:
            return False
        return ext.is_builtin if builtin_only else True

    return [ext for ext in extensions if wanted(ext)]
24
class ExtensionMetadata:
    """Metadata for one extension, read from the optional ``metadata.ini`` in its directory."""

    filename = "metadata.ini"
    config: configparser.ConfigParser
    canonical_name: str
    requires: list

    def __init__(self, path, canonical_name):
        """Load ``metadata.ini`` from the extension directory *path*.

        Args:
            path: the extension's directory.
            canonical_name: fallback name used when the ini file has no
                ``[Extension] Name`` entry (callers in this module pass the directory name).
        """
        self.config = configparser.ConfigParser()

        filepath = os.path.join(path, self.filename)
        # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
        # so no need to check whether the file exists beforehand.
        try:
            self.config.read(filepath)
        except Exception:
            errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)

        # Prefer the name declared in the ini file and normalise whichever name we end up with.
        # The declared name must not be overwritten by the fallback argument, or
        # `[Extension] Name` would never take effect.
        declared_name = self.config.get("Extension", "Name", fallback=canonical_name)
        self.canonical_name = declared_name.lower().strip()

        self.requires = self.get_script_requirements("Requires", "Extension")

    def get_script_requirements(self, field, section, extra_section=None):
        """Read a list of lowercased requirement names from the config.

        Args:
            field: the ini option name, such as ``Requires`` or ``Before``.
            section: the ``[section]`` to read *field* from.
            extra_section: optional second section whose *field* value is appended.

        Returns:
            A list of names; empty when the option is absent.
        """
        x = self.config.get(section, field, fallback='')

        if extra_section:
            x = x + ', ' + self.config.get(extra_section, field, fallback='')

        return self.parse_list(x.lower())

    def parse_list(self, text):
        """Convert a config line such as ``"ext1 ext2, ext3 "`` into ``["ext1", "ext2", "ext3"]``.

        Commas and any whitespace are both accepted as separators, and empty items are dropped.
        """
        if not text:
            return []

        # both "," and " " are accepted as separator
        return [x for x in re.split(r"[,\s]+", text.strip()) if x]
68
class Extension:
    """An installed extension directory plus (lazily loaded) git information about it."""

    # Class-level, so it is shared by every instance: serialises the cached git reads below.
    lock = threading.Lock()
    # Attributes round-tripped through to_dict()/from_dict() when git info is cached.
    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
    metadata: ExtensionMetadata

    def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
        self.name = name
        self.path = path
        self.enabled = enabled
        self.status = ''
        self.can_update = False
        self.is_builtin = is_builtin
        self.commit_hash = ''
        self.commit_date = None
        self.version = ''
        self.branch = None
        self.remote = None
        self.have_info_from_repo = False
        self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
        # Read through self.metadata: when the `metadata` argument is None it was just built
        # above, and `metadata.canonical_name` would raise AttributeError.
        self.canonical_name = self.metadata.canonical_name

    def to_dict(self):
        """Return the cacheable git fields as a plain dict."""
        return {x: getattr(self, x) for x in self.cached_fields}

    def from_dict(self, d):
        """Restore the cacheable git fields from a dict produced by to_dict()."""
        for field in self.cached_fields:
            setattr(self, field, d[field])

    def read_info_from_repo(self):
        """Load git info for this extension, going through the 'extensions-git' cache.

        Does nothing for built-in extensions or when the info was already loaded. A
        FileNotFoundError from the cache lookup is ignored. Afterwards an empty status
        becomes 'unknown'.
        """
        if self.is_builtin or self.have_info_from_repo:
            return

        def read_from_repo():
            with self.lock:
                # Another thread may have finished the read while we waited for the lock.
                if self.have_info_from_repo:
                    return

                self.do_read_info_from_repo()

                return self.to_dict()

        try:
            d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
            self.from_dict(d)
        except FileNotFoundError:
            pass
        self.status = 'unknown' if self.status == '' else self.status

    def do_read_info_from_repo(self):
        """Populate remote/commit/branch/version fields directly from the git checkout.

        ``remote`` stays None when there is no ``.git`` entry, the repo cannot be opened,
        or it is bare. Errors while reading are reported, not raised.
        ``have_info_from_repo`` is always set afterwards.
        """
        repo = None
        try:
            if os.path.exists(os.path.join(self.path, ".git")):
                repo = Repo(self.path)
        except Exception:
            errors.report(f"Error reading github repository info from {self.path}", exc_info=True)

        if repo is None or repo.bare:
            self.remote = None
        else:
            try:
                self.remote = next(repo.remote().urls, None)
                commit = repo.head.commit
                self.commit_date = commit.committed_date
                # `active_branch` raises TypeError on a detached HEAD (e.g. an extension checked
                # out at a specific commit). In that case leave `branch` as None instead of
                # failing the whole read and losing commit_hash/version.
                if not repo.head.is_detached:
                    self.branch = repo.active_branch.name
                self.commit_hash = commit.hexsha
                self.version = self.commit_hash[:8]

            except Exception:
                errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
                self.remote = None

        self.have_info_from_repo = True

    def list_files(self, subdir, extension):
        """Return ScriptFile entries for regular files in *subdir* whose suffix matches *extension*.

        *extension* is compared against the lowercased suffix including the dot (e.g. ``".py"``).
        Files are listed in sorted name order; a missing *subdir* yields ``[]``.
        """
        dirpath = os.path.join(self.path, subdir)
        if not os.path.isdir(dirpath):
            return []

        res = [
            scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename))
            for filename in sorted(os.listdir(dirpath))
        ]

        return [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

    def check_updates(self):
        """Set ``can_update``/``status`` by comparing the checkout against its remote.

        The resulting status is one of:
        - 'new commits' if a dry-run fetch reports anything other than HEAD_UPTODATE;
        - 'behind HEAD' if HEAD differs from ``origin``;
        - 'unknown (remote error)' if resolving ``origin`` fails;
        - 'latest' otherwise.

        Exceptions from the fetch itself propagate to the caller.
        """
        repo = Repo(self.path)
        for fetch in repo.remote().fetch(dry_run=True):
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
                self.status = "new commits"
                return

        try:
            origin = repo.rev_parse('origin')
            if repo.head.commit != origin:
                self.can_update = True
                self.status = "behind HEAD"
                return
        except Exception:
            self.can_update = False
            self.status = "unknown (remote error)"
            return

        self.can_update = False
        self.status = "latest"

    def fetch_and_reset_hard(self, commit='origin'):
        """Fetch all remotes, hard-reset the checkout to *commit*, and invalidate cached git info."""
        repo = Repo(self.path)
        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
        repo.git.fetch(all=True)
        repo.git.reset(commit, hard=True)
        self.have_info_from_repo = False
186
def list_extensions():
    """Rebuild the module-level ``extensions`` list by scanning the extension directories.

    Built-in extensions are scanned before user extensions, each directory in sorted
    order. A directory whose canonical name duplicates an already-loaded extension is
    skipped with an error report. Afterwards each enabled extension's ``Requires`` list
    is checked. Missing or disabled dependencies are only reported, not enforced.
    """
    extensions.clear()

    if shared.cmd_opts.disable_all_extensions:
        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
    elif shared.opts.disable_all_extensions == "all":
        print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
    elif shared.cmd_opts.disable_extra_extensions:
        print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
    elif shared.opts.disable_all_extensions == "extra":
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")

    # canonical name -> Extension; used for duplicate detection and for resolving requirements.
    loaded_extensions = {}

    # scan through extensions directory and load metadata
    for dirname in [extensions_builtin_dir, extensions_dir]:
        if not os.path.isdir(dirname):
            continue

        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
                continue

            canonical_name = extension_dirname
            metadata = ExtensionMetadata(path, canonical_name)

            # check for duplicated canonical names
            already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
            if already_loaded_extension is not None:
                errors.report(f'Duplicate canonical name "{metadata.canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
                continue

            is_builtin = dirname == extensions_builtin_dir
            extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
            extensions.append(extension)
            # Key by the metadata's canonical name: it is what the duplicate check above and
            # the (lowercased) requirement lookups below search for. Keying by the raw
            # directory name could miss duplicates and fail to resolve mixed-case or renamed
            # extensions.
            loaded_extensions[metadata.canonical_name] = extension

    # check for requirements
    for extension in extensions:
        if not extension.enabled:
            continue

        for req in extension.metadata.requires:
            required_extension = loaded_extensions.get(req)
            if required_extension is None:
                errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
                continue

            if not required_extension.enabled:
                errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
240
# Module-wide registry of discovered extensions; list_extensions() clears and refills it in place.
extensions: list[Extension] = list()