skypilot
11 строк · 302.0 Байт
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
from flash_attn_patch import replace_llama_attn_with_flash_attn

# Apply the patch at import time, before the training module is loaded.
replace_llama_attn_with_flash_attn()

# Deliberately imported after the patch call above (presumably because `train`
# pulls in transformers -- confirm); do not hoist this to the top of the file.
from train import train  # noqa: E402

if __name__ == "__main__":
    train()