longlm
/
modify_utils.py
50 lines · 2.7 KB
from types import MethodType


def modify_method_of_instance(
    instance,
    target_class_name,
    target_method_name,
    new_method,
    visited_instances=None,
):
    """Bind ``new_method`` as ``target_method_name`` on matching objects.

    Walks the object graph reachable from ``instance``. For every object
    whose class is named ``target_class_name``, it replaces the attribute
    ``target_method_name`` with ``new_method`` bound to that object.
    Currently this is only used to swap the attention method (e.g.
    ``forward``) of a model, and it has not been exercised beyond that.
    (Originally drafted with ChatGPT.)

    The walk follows:
      * instance attributes (``obj.__dict__`` values),
      * the values of ``dict`` objects. For example, ``torch.nn.Module``
        keeps its sub-modules in the ``_modules`` dict. Dict keys are not
        searched.
      * the items of ``list``/``tuple``/``set``/``frozenset`` objects,
        including containers nested inside other containers.
    A matched object is patched but not searched further. Matching is by
    class *name* only, so subclasses with a different name are not matched.

    Args:
        instance: Root object to search, e.g. a model instance.
        target_class_name: Name of the class to patch, e.g.
            'LlamaAttention', 'GPTNeoXAttention'.
        target_method_name: Name of the attribute to replace, e.g.
            'forward'.
        new_method: Plain function to install, e.g.
            ``self_extend_forward``. Its first parameter receives the
            patched object (``self``).
        visited_instances: Optional set of ``id()`` values of objects to
            skip. It is updated in place with every object visited, so it
            can be shared across calls.
            NOTE: ids of objects that were garbage-collected after an
            earlier call may be reused by new objects.

    Returns:
        bool: True if at least one object was patched, False otherwise
        (e.g. when ``target_class_name`` does not occur in the graph).
    """
    if visited_instances is None:
        visited_instances = set()

    patched = False
    # Use an explicit stack rather than recursion, because an arbitrary
    # attribute graph can be deeper than Python's recursion limit. The set of
    # patched objects does not depend on traversal order, so LIFO is fine.
    pending = [instance]
    while pending:
        obj = pending.pop()
        # id() is unique among live objects. Objects reached here normally
        # stay referenced by the graph for the whole walk.
        obj_id = id(obj)
        if obj_id in visited_instances:
            continue
        visited_instances.add(obj_id)

        if obj.__class__.__name__ == target_class_name:
            setattr(obj, target_method_name, MethodType(new_method, obj))
            patched = True
            continue  # Do not descend into a patched object.

        # extend() copies the contents immediately, so patching later cannot
        # cause "changed size during iteration".
        if isinstance(obj, dict):
            pending.extend(obj.values())
        elif isinstance(obj, (list, tuple, set, frozenset)):
            pending.extend(obj)
        else:
            attrs = getattr(obj, "__dict__", None)
            if attrs is not None:
                pending.extend(attrs.values())

    return patched