stable-diffusion-webui
32 lines · 1.4 KB
1import importlib
2
class CondFunc:
    """Route calls of a function to a substitute whenever a condition holds.

    Constructing ``CondFunc(orig_func, sub_func, cond_func)`` does not yield a
    ``CondFunc`` instance: it returns a forwarding callable. Each call of that
    callable returns ``sub_func(orig_func, *args, **kwargs)`` when ``cond_func``
    is falsy (e.g. ``None``) or ``cond_func(orig_func, *args, **kwargs)`` is
    truthy, and ``orig_func(*args, **kwargs)`` otherwise.

    ``orig_func`` may be a callable or a dotted path string such as
    ``"package.module.Class.method"``. For a path, the longest importable
    module prefix is imported, the remaining segments are walked with
    ``getattr``, and the final attribute is replaced on its owner with a
    forwarding callable, so lookups through that owner see the substitution.
    If the path cannot be resolved or patched, a warning is printed and the
    owner is left unchanged. When resolution fails before the target itself
    is found, ``orig_func`` stays the path string, so the returned callable
    should not be relied on in that case.
    """

    def __new__(cls, orig_func, sub_func, cond_func):
        """Create the wrapper, optionally patch a dotted path, and return a forwarder.

        Args:
            orig_func: Callable to wrap, or a dotted path string naming it.
            sub_func: Called as ``sub_func(orig_func, *args, **kwargs)`` when the
                condition holds.
            cond_func: Called as ``cond_func(orig_func, *args, **kwargs)``; a falsy
                value here (e.g. ``None``) means "always substitute".

        Returns:
            A plain lambda that forwards its arguments to the new instance.
        """
        self = super().__new__(cls)
        if isinstance(orig_func, str):
            # Keep the path separately: orig_func may be rebound to the resolved
            # function before a later step fails, and the warning should name the path.
            path = orig_func
            func_path = path.split('.')
            found = cls._import_longest_prefix(func_path)
            if found is None:
                # No module prefix is importable (or the path has one segment).
                print(f"Warning: Failed to resolve {path} for CondFunc hijack")
            else:
                resolved_obj, i = found
                try:
                    # Walk the non-module part of the path down to the owner of
                    # the target attribute (e.g. a class for "mod.Class.method").
                    for attr_name in func_path[i:-1]:
                        resolved_obj = getattr(resolved_obj, attr_name)
                    orig_func = getattr(resolved_obj, func_path[-1])
                    setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
                except AttributeError:
                    print(f"Warning: Failed to resolve {path} for CondFunc hijack")
        # __new__ returns something that is not an instance of cls, so Python
        # will not call __init__ automatically; do it explicitly.
        self.__init__(orig_func, sub_func, cond_func)
        return lambda *args, **kwargs: self(*args, **kwargs)

    @staticmethod
    def _import_longest_prefix(func_path):
        """Import the longest proper, non-empty module prefix of ``func_path``.

        Args:
            func_path: A dotted path already split into its segments.

        Returns:
            ``(module, i)`` where ``module`` was imported from
            ``'.'.join(func_path[:i])``, or ``None`` if no such prefix imports.
            The last segment is never tried as a module because it names the
            attribute to patch, and the empty prefix is never tried because
            ``importlib.import_module('')`` raises ``ValueError``.
        """
        for i in range(len(func_path) - 1, 0, -1):
            try:
                return importlib.import_module('.'.join(func_path[:i])), i
            except ImportError:
                continue
        return None

    def __init__(self, orig_func, sub_func, cond_func):
        """Store the original function, its substitute, and the condition."""
        self.__orig_func = orig_func
        self.__sub_func = sub_func
        self.__cond_func = cond_func

    def __call__(self, *args, **kwargs):
        """Dispatch to the substitute or the original depending on the condition."""
        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
            return self.__sub_func(self.__orig_func, *args, **kwargs)
        return self.__orig_func(*args, **kwargs)
33