skypilot

Форк
0
/
train_flash_attn.py 
11 строк · 302.0 Байт
1
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
2

3
# Need to call this before importing transformers.
4
from flash_attn_patch import replace_llama_attn_with_flash_attn
5

6
replace_llama_attn_with_flash_attn()
7

8
from train import train
9

10
if __name__ == "__main__":
11
    train()
12

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.