# Source: colossalai (18 lines, 533.0 bytes)
1import torch.nn as nn2from transformers.models.opt.modeling_opt import OPTAttention3
4from .opt_attn import XOPTAttention5
6
def convert_to_xformer_model(model: nn.Module) -> nn.Module:
    """Swap every ``OPTAttention`` submodule for the xFormers variant, in place.

    Reassigns ``__class__`` on each matching submodule, so existing weights
    and buffers are preserved — only the method implementations change.

    Args:
        model: Any module tree; typically a HuggingFace OPT model.

    Returns:
        The same ``model`` object, mutated in place.
    """
    attn_layers = (m for m in model.modules() if isinstance(m, OPTAttention))
    for attn in attn_layers:
        attn.__class__ = XOPTAttention
    return model
13
def recover_from_xformer_model(model: nn.Module) -> nn.Module:
    """Undo :func:`convert_to_xformer_model`, restoring stock attention.

    Reassigns ``__class__`` back to ``OPTAttention`` on every submodule that
    was previously switched to ``XOPTAttention``; weights are untouched.

    Args:
        model: A module tree previously converted to the xFormers variant.

    Returns:
        The same ``model`` object, mutated in place.
    """
    xformer_layers = (m for m in model.modules() if isinstance(m, XOPTAttention))
    for attn in xformer_layers:
        attn.__class__ = OPTAttention
    return model