colossalai

Форк
0
18 строк · 533.0 Байт
1
import torch.nn as nn
2
from transformers.models.opt.modeling_opt import OPTAttention
3

4
from .opt_attn import XOPTAttention
5

6

7
def convert_to_xformer_model(model: nn.Module) -> nn.Module:
8
    for module in model.modules():
9
        if isinstance(module, OPTAttention):
10
            module.__class__ = XOPTAttention
11
    return model
12

13

14
def recover_from_xformer_model(model: nn.Module) -> nn.Module:
15
    for module in model.modules():
16
        if isinstance(module, XOPTAttention):
17
            module.__class__ = OPTAttention
18
    return model
19

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.