longlm

Форк
0
/
modify_utils.py 
50 строк · 2.7 Кб
1
from types import MethodType
2

3

4
def modify_method_of_instance(instance, target_class_name, target_method_name, new_method, visited_instances=None):
5
    """
6
        This function modifies the method of an instance of a model class. 
7
        It's part from chat-GPT.
8
        It will replace the method  with the new method.
9
        Currently, we only use this function to modify the attention method of a model. Do not test it further. 
10

11
        instance: 
12
            instance of a model to modify.
13
        target_class_name: 
14
            name of the attention class to modify. E.g. 'LlamaAttention', 'GPTNeoXAttention', etc.
15
        new_method: new method to replace the original method. E.g. 'self_extend_forward'. 
16
            It should include a parameter 'self' to be binded to the instance.
17
    """
18
    if visited_instances is None:
19
        visited_instances = set()
20
    # Unique identifier for the instance (using id() since object's id is unique)
21
    instance_id = id(instance)
22
    if instance_id in visited_instances:
23
        return 
24
    # Add the instance to the already_visited set
25
    visited_instances.add(instance_id)
26

27
    # Check if this instance is of the target class
28
    if instance.__class__.__name__ == target_class_name:
29
        bond_method = MethodType(new_method, instance) 
30
        setattr(instance, target_method_name, bond_method)
31
    elif hasattr(instance, '__dict__'):
32
        for attr_name, attr_value in instance.__dict__.items():
33
            if isinstance(attr_value, object) and not isinstance(attr_value, (list, tuple, dict, set)):
34
                modify_method_of_instance(attr_value, target_class_name, target_method_name, new_method, visited_instances)
35
            elif isinstance(attr_value, (list, tuple)):
36
                for item in attr_value:
37
                    if isinstance(item, object):
38
                        modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
39
            # If attribute value is a dictionary, iterate over its values and recurse
40
            # E.g, for a ModuleList, its moudels are stored in a dictionary: ._modules
41
            elif isinstance(attr_value, dict):
42
                for key, value in attr_value.items():
43
                    if isinstance(value, object):
44
                        modify_method_of_instance(value, target_class_name, target_method_name, new_method, visited_instances)
45
            
46
            # If attribute value is a set, iterate and recurse
47
            elif isinstance(attr_value, set):
48
                for item in attr_value:
49
                    if isinstance(item, object):
50
                        modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
51
    

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.