# text-generation-inference — test for the PyTorch → safetensors weight conversion
# (original listing: 21 lines · 659.0 bytes)
from text_generation_server.utils.convert import convert_files
from text_generation_server.utils.hub import (
    download_weights,
    weight_files,
    weight_hub_files,
)
9
def test_convert_files():
    """Download the PyTorch weights of a small model, convert them to
    safetensors, and verify the converted files are discoverable.

    NOTE(review): requires network access to the Hugging Face Hub; the
    downloaded shards are cached by the hub utilities.
    """
    model_id = "bigscience/bloom-560m"

    # Locate and fetch the original ``.bin`` (PyTorch) weight shards.
    pt_filenames = weight_hub_files(model_id, extension=".bin")
    local_pt_files = download_weights(pt_filenames, model_id)

    # Derive the target safetensors path for each PyTorch file.
    # BUG FIX: ``str.lstrip("pytorch_")`` strips any leading characters drawn
    # from the set {p,y,t,o,r,c,h,_}, not the literal prefix (flake8-bugbear
    # B005) — a stem such as ``pytorch_tf_model`` would lose its leading "t".
    # ``str.removeprefix`` (Python 3.9+) removes exactly the prefix.
    local_st_files = [
        p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
        for p in local_pt_files
    ]
    convert_files(local_pt_files, local_st_files, discard_names=[])

    found_st_files = weight_files(model_id)

    # Every converted file must be reported by the weight-file lookup.
    assert all(p in found_st_files for p in local_st_files)
22