Langchain-Chatchat / llm_api_stale.py
"""
Usage example:
    python llm_api_stale.py --model-path-address THUDM/chatglm2-6b@localhost@7650 THUDM/chatglm2-6b-32k@localhost@7651

Other fastchat.serve.controller/model_worker/openai_api_server arguments can be
passed as documented by fastchat, but a few non-critical arguments such as
--worker-address, --allowed-origins, --allowed-methods, and --allowed-headers
are not supported.
"""
import sys
import os

# Make the repository root importable when this script is run directly.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

import subprocess
import re
import logging
import argparse
LOG_PATH = "./logs/"
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
parser = argparse.ArgumentParser()
# ------multi worker-----------------
parser.add_argument('--model-path-address',
                    default="THUDM/chatglm2-6b@localhost@20002",
                    nargs="+",
                    type=str,
                    help="model path, host, and port, formatted as model-path@host@port")
# ---------------controller-------------------------

parser.add_argument("--controller-host", type=str, default="localhost")
parser.add_argument("--controller-port", type=int, default=21001)
parser.add_argument(
    "--dispatch-method",
    type=str,
    choices=["lottery", "shortest_queue"],
    default="shortest_queue",
)
controller_args = ["controller-host", "controller-port", "dispatch-method"]

# ----------------------worker------------------------------------------

parser.add_argument("--worker-host", type=str, default="localhost")
parser.add_argument("--worker-port", type=int, default=21002)
# parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
# parser.add_argument(
#     "--controller-address", type=str, default="http://localhost:21001"
# )
parser.add_argument(
    "--model-path",
    type=str,
    default="lmsys/vicuna-7b-v1.3",
    help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
    "--revision",
    type=str,
    default="main",
    help="Hugging Face Hub model revision identifier",
)
parser.add_argument(
    "--device",
    type=str,
    choices=["cpu", "cuda", "mps", "xpu"],
    default="cuda",
    help="The device type",
)
parser.add_argument(
    "--gpus",
    type=str,
    default="0",
    help="A single GPU like 1 or multiple GPUs like 0,2",
)
parser.add_argument("--num-gpus", type=int, default=1)
parser.add_argument(
    "--max-gpu-memory",
    type=str,
    default="20GiB",
    help="The maximum memory per GPU. Use a string like '13GiB'",
)
parser.add_argument(
    "--load-8bit", action="store_true", help="Use 8-bit quantization"
)
parser.add_argument(
    "--cpu-offloading",
    action="store_true",
    help="Only when using 8-bit quantization: offload excess weights that don't fit on the GPU to the CPU",
)
parser.add_argument(
    "--gptq-ckpt",
    type=str,
    default=None,
    help="Load a quantized model. The path to the local GPTQ checkpoint.",
)
parser.add_argument(
    "--gptq-wbits",
    type=int,
    default=16,
    choices=[2, 3, 4, 8, 16],
    help="Number of bits to use for quantization",
)
parser.add_argument(
    "--gptq-groupsize",
    type=int,
    default=-1,
    help="Group size to use for quantization; default uses full row.",
)
parser.add_argument(
    "--gptq-act-order",
    action="store_true",
    help="Whether to apply the activation-order GPTQ heuristic",
)
parser.add_argument(
    "--model-names",
    type=lambda s: s.split(","),
    help="Optional comma-separated list of display names",
)
parser.add_argument(
    "--limit-worker-concurrency",
    type=int,
    default=5,
    help="Limit the model concurrency to prevent OOM.",
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")

worker_args = [
    "worker-host", "worker-port",
    "model-path", "revision", "device", "gpus", "num-gpus",
    "max-gpu-memory", "load-8bit", "cpu-offloading",
    "gptq-ckpt", "gptq-wbits", "gptq-groupsize",
    "gptq-act-order", "model-names", "limit-worker-concurrency",
    "stream-interval", "no-register",
    "controller-address", "worker-address"
]
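
# Note: "controller-address" and "worker-address" have no parser entries above;
# they are filled in at runtime (controller-address in the __main__ block,
# worker-address in launch_worker) before string_args reads them.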
# -----------------openai server---------------------------

parser.add_argument("--server-host", type=str, default="localhost", help="host name")
parser.add_argument("--server-port", type=int, default=8888, help="port number")
parser.add_argument(
    "--allow-credentials", action="store_true", help="allow credentials"
)
# parser.add_argument(
#     "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
# )
# parser.add_argument(
#     "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
# )
# parser.add_argument(
#     "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
# )
parser.add_argument(
    "--api-keys",
    type=lambda s: s.split(","),
    help="Optional list of comma-separated API keys",
)
server_args = ["server-host", "server-port", "allow-credentials", "api-keys",
               "controller-address"
               ]

# Placeholders in base_launch_sh:
#   0: module to launch (controller, model_worker, openai_api_server)
#   1: command-line options
#   2: LOG_PATH
#   3: log file name
base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"

# Placeholders in base_check_sh:
#   0: LOG_PATH
#   1: log file name (must match the one used in base_launch_sh)
#   2: module name (controller, model_worker, openai_api_server)
base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do
                        sleep 5s;
                        echo "wait {2} running"
                done
                echo '{2} running' """
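
# For illustration only: with the defaults above, the rendered controller
# command looks roughly like
#   nohup python3 -m fastchat.serve.controller --host localhost --port 21001 \
#       --dispatch-method shortest_queue >./logs/controller.log 2>&1 &
# and base_check_sh then polls ./logs/controller.log until "Uvicorn running on"
# appears.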


def string_args(args, args_list):
    """Convert the selected keys in args into a command-line argument string."""
    args_str = ""
    for key, value in args._get_kwargs():
        # Keys from args._get_kwargs() use "_" as the separator; convert to "-"
        # first, then check membership in the given args list.
        key = key.replace("_", "-")
        if key not in args_list:
            continue
        # fastchat's port/host arguments have no prefix, so strip ours
        # (e.g. "controller-host" becomes "host").
        key = key.split("-")[-1] if re.search("port|host", key) else key
        if not value:
            # Skip empty/False/None values entirely.
            pass
        elif isinstance(value, bool) and value is True:
            # Boolean flags are emitted without a value.
            args_str += f" --{key} "
        elif isinstance(value, (list, tuple, set)):
            value = " ".join(value)
            args_str += f" --{key} {value} "
        else:
            args_str += f" --{key} {value} "

    return args_str
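
# Worked example (hypothetical Namespace, for illustration): given
#   ns = argparse.Namespace(controller_host="localhost", controller_port=21001,
#                           dispatch_method="shortest_queue")
# string_args(ns, controller_args) returns
#   " --host localhost  --port 21001  --dispatch-method shortest_queue "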


def launch_worker(item, args, worker_args=worker_args):
    # Derive a log file name from the model-path@host@port triple.
    log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_")
    # Split model-path-address first, then pass the pieces to string_args.
    args.model_path, args.worker_host, args.worker_port = item.split("@")
    args.worker_address = f"http://{args.worker_host}:{args.worker_port}"
    print("*" * 80)
    print(f"If startup takes a long time, check the log at {LOG_PATH}{log_name}.log")
    worker_str_args = string_args(args, worker_args)
    print(worker_str_args)
    worker_sh = base_launch_sh.format("model_worker", worker_str_args, LOG_PATH, f"worker_{log_name}")
    worker_check_sh = base_check_sh.format(LOG_PATH, f"worker_{log_name}", "model_worker")
    subprocess.run(worker_sh, shell=True, check=True)
    subprocess.run(worker_check_sh, shell=True, check=True)
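
# For example, item = "THUDM/chatglm2-6b@localhost@20002" is split into
# model_path "THUDM/chatglm2-6b", worker_host "localhost", worker_port "20002",
# giving worker_address "http://localhost:20002" and log file
# worker_chatglm2_6b_localhost_20002.log under ./logs/.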


def launch_all(args,
               controller_args=controller_args,
               worker_args=worker_args,
               server_args=server_args
               ):
    print(f"Launching LLM service; monitor each module's log under {LOG_PATH}...")
    controller_str_args = string_args(args, controller_args)
    controller_sh = base_launch_sh.format("controller", controller_str_args, LOG_PATH, "controller")
    controller_check_sh = base_check_sh.format(LOG_PATH, "controller", "controller")
    subprocess.run(controller_sh, shell=True, check=True)
    subprocess.run(controller_check_sh, shell=True, check=True)
    print("Worker startup time varies by device and usually takes 3-10 minutes; please be patient...")
    if isinstance(args.model_path_address, str):
        launch_worker(args.model_path_address, args=args, worker_args=worker_args)
    else:
        for idx, item in enumerate(args.model_path_address):
            print(f"Loading model {idx}: {item}")
            launch_worker(item, args=args, worker_args=worker_args)

    server_str_args = string_args(args, server_args)
    server_sh = base_launch_sh.format("openai_api_server", server_str_args, LOG_PATH, "openai_api_server")
    server_check_sh = base_check_sh.format(LOG_PATH, "openai_api_server", "openai_api_server")
    subprocess.run(server_sh, shell=True, check=True)
    subprocess.run(server_check_sh, shell=True, check=True)
    print("Launching LLM service done!")


if __name__ == "__main__":
    args = parser.parse_args()
    # The "http://" prefix is required; without it requests raises
    # "InvalidSchema: No connection adapters were found".
    # The dashed key is intentional: string_args reads it via _get_kwargs().
    args = argparse.Namespace(**vars(args),
                              **{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"})

    if args.gpus:
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    launch_all(args=args)
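
# Once the final "running" message appears, an OpenAI-compatible API should be
# reachable at http://{server-host}:{server-port}/v1 (http://localhost:8888/v1
# with the defaults above), e.g.:
#   curl http://localhost:8888/v1/models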