Langchain-Chatchat

startup.py 
896 lines · 30.9 KB
1
import asyncio
2
import multiprocessing as mp
3
import os
4
import subprocess
5
import sys
6
from multiprocessing import Process
7
from datetime import datetime
8
from pprint import pprint
9
from langchain_core._api import deprecated
10

11
try:
12
    import numexpr
13

14
    n_cores = numexpr.utils.detect_number_of_cores()
15
    os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
16
except Exception:
17
    pass
18

19
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
20
from configs import (
21
    LOG_PATH,
22
    log_verbose,
23
    logger,
24
    LLM_MODELS,
25
    EMBEDDING_MODEL,
26
    TEXT_SPLITTER_NAME,
27
    FSCHAT_CONTROLLER,
28
    FSCHAT_OPENAI_API,
29
    FSCHAT_MODEL_WORKERS,
30
    API_SERVER,
31
    WEBUI_SERVER,
32
    HTTPX_DEFAULT_TIMEOUT,
33
)
34
from server.utils import (fschat_controller_address, fschat_model_worker_address,
35
                          fschat_openai_api_address, get_httpx_client, get_model_worker_config,
36
                          MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
37
from server.knowledge_base.migrate import create_tables
38
import argparse
39
from typing import List, Dict, Tuple
40
from configs import VERSION
41

42

43
@deprecated(
44
    since="0.3.0",
45
    message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动,0.2.x中相关功能将废弃",
46
    removal="0.3.0")
47
def create_controller_app(
48
        dispatch_method: str,
49
        log_level: str = "INFO",
50
) -> FastAPI:
51
    import fastchat.constants
52
    fastchat.constants.LOGDIR = LOG_PATH
53
    from fastchat.serve.controller import app, Controller, logger
54
    logger.setLevel(log_level)
55

56
    controller = Controller(dispatch_method)
57
    sys.modules["fastchat.serve.controller"].controller = controller
58

59
    MakeFastAPIOffline(app)
60
    app.title = "FastChat Controller"
61
    app._controller = controller
62
    return app
63

64

65
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
66
    """
67
    kwargs may contain the following fields:
68
    host:
69
    port:
70
    model_names:[`model_name`]
71
    controller_address:
72
    worker_address:
73

74
    For models supported natively by Langchain:
75
        langchain_model:True
76
        fschat is not used
77
    For online API models:
78
        online_api:True
79
        worker_class: `provider`
80
    For local (offline) models:
81
        model_path: `model_name_or_path`, a huggingface repo-id or a local path
82
        device:`LLM_DEVICE`
83
    """
84
    import fastchat.constants
85
    fastchat.constants.LOGDIR = LOG_PATH
86
    import argparse
87

88
    parser = argparse.ArgumentParser()
89
    args = parser.parse_args([])
90

91
    for k, v in kwargs.items():
92
        setattr(args, k, v)
93
    if worker_class := kwargs.get("langchain_model"):  # models supported by Langchain need no extra setup
94
        from fastchat.serve.base_model_worker import app
95
        worker = ""
96
    # online model API
97
    elif worker_class := kwargs.get("worker_class"):
98
        from fastchat.serve.base_model_worker import app
99

100
        worker = worker_class(model_names=args.model_names,
101
                              controller_addr=args.controller_address,
102
                              worker_addr=args.worker_address)
103
        # sys.modules["fastchat.serve.base_model_worker"].worker = worker
104
        sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
105
    # local model
106
    else:
107
        from configs.model_config import VLLM_MODEL_DICT
108
        if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
109
            import fastchat.serve.vllm_worker
110
            from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
111
            from vllm import AsyncLLMEngine
112
            from vllm.engine.arg_utils import AsyncEngineArgs
113

114
            args.tokenizer = args.model_path
115
            args.tokenizer_mode = 'auto'
116
            args.trust_remote_code = True
117
            args.download_dir = None
118
            args.load_format = 'auto'
119
            args.dtype = 'auto'
120
            args.seed = 0
121
            args.worker_use_ray = False
122
            args.pipeline_parallel_size = 1
123
            args.tensor_parallel_size = 1
124
            args.block_size = 16
125
            args.swap_space = 4  # GiB
126
            args.gpu_memory_utilization = 0.90
127
            args.max_num_batched_tokens = None  # maximum number of tokens per batch; depends on your GPU and model settings, too large a value will exhaust GPU memory
128
            args.max_num_seqs = 256
129
            args.disable_log_stats = False
130
            args.conv_template = None
131
            args.limit_worker_concurrency = 5
132
            args.no_register = False
133
            args.num_gpus = 1  # the vllm worker shards the model with tensor parallelism; set this to the number of GPUs
134
            args.engine_use_ray = False
135
            args.disable_log_requests = False
136

137
            # parameters required since vllm 0.2.1, but not needed here
138
            args.max_model_len = None
139
            args.revision = None
140
            args.quantization = None
141
            args.max_log_len = None
142
            args.tokenizer_revision = None
143

144
            # new parameter required by vllm 0.2.2
145
            args.max_paddings = 256
146

147
            if args.model_path:
148
                args.model = args.model_path
149
            if args.num_gpus > 1:
150
                args.tensor_parallel_size = args.num_gpus
151

152
            for k, v in kwargs.items():
153
                setattr(args, k, v)
154

155
            engine_args = AsyncEngineArgs.from_cli_args(args)
156
            engine = AsyncLLMEngine.from_engine_args(engine_args)
157

158
            worker = VLLMWorker(
159
                controller_addr=args.controller_address,
160
                worker_addr=args.worker_address,
161
                worker_id=worker_id,
162
                model_path=args.model_path,
163
                model_names=args.model_names,
164
                limit_worker_concurrency=args.limit_worker_concurrency,
165
                no_register=args.no_register,
166
                llm_engine=engine,
167
                conv_template=args.conv_template,
168
            )
169
            sys.modules["fastchat.serve.vllm_worker"].engine = engine
170
            sys.modules["fastchat.serve.vllm_worker"].worker = worker
171
            sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
172

173
        else:
174
            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
175

176
            args.gpus = "0"  # GPU id(s); with multiple GPUs this can be set to "0,1,2,3"
177
            args.max_gpu_memory = "22GiB"
178
            args.num_gpus = 1  # the model worker shards the model with model parallelism; set this to the number of GPUs
179

180
            args.load_8bit = False
181
            args.cpu_offloading = None
182
            args.gptq_ckpt = None
183
            args.gptq_wbits = 16
184
            args.gptq_groupsize = -1
185
            args.gptq_act_order = False
186
            args.awq_ckpt = None
187
            args.awq_wbits = 16
188
            args.awq_groupsize = -1
189
            args.model_names = [""]
190
            args.conv_template = None
191
            args.limit_worker_concurrency = 5
192
            args.stream_interval = 2
193
            args.no_register = False
194
            args.embed_in_truncate = False
195
            for k, v in kwargs.items():
196
                setattr(args, k, v)
197
            if args.gpus:
198
                if args.num_gpus is None:
199
                    args.num_gpus = len(args.gpus.split(','))
200
                if len(args.gpus.split(",")) < args.num_gpus:
201
                    raise ValueError(
202
                        f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
203
                    )
204
                os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
205
            gptq_config = GptqConfig(
206
                ckpt=args.gptq_ckpt or args.model_path,
207
                wbits=args.gptq_wbits,
208
                groupsize=args.gptq_groupsize,
209
                act_order=args.gptq_act_order,
210
            )
211
            awq_config = AWQConfig(
212
                ckpt=args.awq_ckpt or args.model_path,
213
                wbits=args.awq_wbits,
214
                groupsize=args.awq_groupsize,
215
            )
216

217
            worker = ModelWorker(
218
                controller_addr=args.controller_address,
219
                worker_addr=args.worker_address,
220
                worker_id=worker_id,
221
                model_path=args.model_path,
222
                model_names=args.model_names,
223
                limit_worker_concurrency=args.limit_worker_concurrency,
224
                no_register=args.no_register,
225
                device=args.device,
226
                num_gpus=args.num_gpus,
227
                max_gpu_memory=args.max_gpu_memory,
228
                load_8bit=args.load_8bit,
229
                cpu_offloading=args.cpu_offloading,
230
                gptq_config=gptq_config,
231
                awq_config=awq_config,
232
                stream_interval=args.stream_interval,
233
                conv_template=args.conv_template,
234
                embed_in_truncate=args.embed_in_truncate,
235
            )
236
            sys.modules["fastchat.serve.model_worker"].args = args
237
            sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
238
            # sys.modules["fastchat.serve.model_worker"].worker = worker
239
            sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
240

241
    MakeFastAPIOffline(app)
242
    app.title = f"FastChat LLM Server ({args.model_names[0]})"
243
    app._worker = worker
244
    return app
245

246

247
def create_openai_api_app(
248
        controller_address: str,
249
        api_keys: List = [],
250
        log_level: str = "INFO",
251
) -> FastAPI:
252
    import fastchat.constants
253
    fastchat.constants.LOGDIR = LOG_PATH
254
    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
255
    from fastchat.utils import build_logger
256
    logger = build_logger("openai_api", "openai_api.log")
257
    logger.setLevel(log_level)
258

259
    app.add_middleware(
260
        CORSMiddleware,
261
        allow_credentials=True,
262
        allow_origins=["*"],
263
        allow_methods=["*"],
264
        allow_headers=["*"],
265
    )
266

267
    sys.modules["fastchat.serve.openai_api_server"].logger = logger
268
    app_settings.controller_address = controller_address
269
    app_settings.api_keys = api_keys
270

271
    MakeFastAPIOffline(app)
272
    app.title = "FastChat OpenAI API Server"
273
    return app
274

275

276
def _set_app_event(app: FastAPI, started_event: mp.Event = None):
277
    @app.on_event("startup")
278
    async def on_startup():
279
        if started_event is not None:
280
            started_event.set()
281

282

283
def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
284
    import uvicorn
285
    import httpx
286
    from fastapi import Body
287
    import time
288
    import sys
289
    from server.utils import set_httpx_config
290
    set_httpx_config()
291

292
    app = create_controller_app(
293
        dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
294
        log_level=log_level,
295
    )
296
    _set_app_event(app, started_event)
297

298
    # add interface to release and load model worker
299
    @app.post("/release_worker")
300
    def release_worker(
301
            model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
302
            # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]),
303
            new_model_name: str = Body(None, description="释放后加载该模型"),
304
            keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
305
    ) -> Dict:
306
        available_models = app._controller.list_models()
307
        if new_model_name in available_models:
308
            msg = f"要切换的LLM模型 {new_model_name} 已经存在"
309
            logger.info(msg)
310
            return {"code": 500, "msg": msg}
311

312
        if new_model_name:
313
            logger.info(f"开始切换LLM模型:从 {model_name}{new_model_name}")
314
        else:
315
            logger.info(f"即将停止LLM模型: {model_name}")
316

317
        if model_name not in available_models:
318
            msg = f"the model {model_name} is not available"
319
            logger.error(msg)
320
            return {"code": 500, "msg": msg}
321

322
        worker_address = app._controller.get_worker_address(model_name)
323
        if not worker_address:
324
            msg = f"can not find model_worker address for {model_name}"
325
            logger.error(msg)
326
            return {"code": 500, "msg": msg}
327

328
        with get_httpx_client() as client:
329
            r = client.post(worker_address + "/release",
330
                            json={"new_model_name": new_model_name, "keep_origin": keep_origin})
331
            if r.status_code != 200:
332
                msg = f"failed to release model: {model_name}"
333
                logger.error(msg)
334
                return {"code": 500, "msg": msg}
335

336
        if new_model_name:
337
            timer = HTTPX_DEFAULT_TIMEOUT  # wait for new model_worker register
338
            while timer > 0:
339
                models = app._controller.list_models()
340
                if new_model_name in models:
341
                    break
342
                time.sleep(1)
343
                timer -= 1
344
            if timer > 0:
345
                msg = f"sucess change model from {model_name} to {new_model_name}"
346
                logger.info(msg)
347
                return {"code": 200, "msg": msg}
348
            else:
349
                msg = f"failed change model from {model_name} to {new_model_name}"
350
                logger.error(msg)
351
                return {"code": 500, "msg": msg}
352
        else:
353
            msg = f"sucess to release model: {model_name}"
354
            logger.info(msg)
355
            return {"code": 200, "msg": msg}
356

357
    host = FSCHAT_CONTROLLER["host"]
358
    port = FSCHAT_CONTROLLER["port"]
359

360
    if log_level == "ERROR":
361
        sys.stdout = sys.__stdout__
362
        sys.stderr = sys.__stderr__
363

364
    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
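    # A sketch (not in the original source) of how the /release_worker endpoint defined above
    # could be called once the controller is running; the host/port are placeholders and in
    # practice come from FSCHAT_CONTROLLER:
    #   with get_httpx_client() as client:
    #       client.post("http://127.0.0.1:20001/release_worker",
    #                   json={"model_name": "chatglm-6b",
    #                         "new_model_name": None,
    #                         "keep_origin": False})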
365

366

367
def run_model_worker(
368
        model_name: str = LLM_MODELS[0],
369
        controller_address: str = "",
370
        log_level: str = "INFO",
371
        q: mp.Queue = None,
372
        started_event: mp.Event = None,
373
):
374
    import uvicorn
375
    from fastapi import Body
376
    import sys
377
    from server.utils import set_httpx_config
378
    set_httpx_config()
379

380
    kwargs = get_model_worker_config(model_name)
381
    host = kwargs.pop("host")
382
    port = kwargs.pop("port")
383
    kwargs["model_names"] = [model_name]
384
    kwargs["controller_address"] = controller_address or fschat_controller_address()
385
    kwargs["worker_address"] = fschat_model_worker_address(model_name)
386
    model_path = kwargs.get("model_path", "")
387
    kwargs["model_path"] = model_path
388

389
    app = create_model_worker_app(log_level=log_level, **kwargs)
390
    _set_app_event(app, started_event)
391
    if log_level == "ERROR":
392
        sys.stdout = sys.__stdout__
393
        sys.stderr = sys.__stderr__
394

395
    # add interface to release and load model
396
    @app.post("/release")
397
    def release_model(
398
            new_model_name: str = Body(None, description="释放后加载该模型"),
399
            keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
400
    ) -> Dict:
401
        if keep_origin:
402
            if new_model_name:
403
                q.put([model_name, "start", new_model_name])
404
        else:
405
            if new_model_name:
406
                q.put([model_name, "replace", new_model_name])
407
            else:
408
                q.put([model_name, "stop", None])
409
        return {"code": 200, "msg": "done"}
410

411
    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
412

413

414
def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
415
    import uvicorn
416
    import sys
417
    from server.utils import set_httpx_config
418
    set_httpx_config()
419

420
    controller_addr = fschat_controller_address()
421
    app = create_openai_api_app(controller_addr, log_level=log_level)
422
    _set_app_event(app, started_event)
423

424
    host = FSCHAT_OPENAI_API["host"]
425
    port = FSCHAT_OPENAI_API["port"]
426
    if log_level == "ERROR":
427
        sys.stdout = sys.__stdout__
428
        sys.stderr = sys.__stderr__
429
    uvicorn.run(app, host=host, port=port)
430

431

432
def run_api_server(started_event: mp.Event = None, run_mode: str = None):
433
    from server.api import create_app
434
    import uvicorn
435
    from server.utils import set_httpx_config
436
    set_httpx_config()
437

438
    app = create_app(run_mode=run_mode)
439
    _set_app_event(app, started_event)
440

441
    host = API_SERVER["host"]
442
    port = API_SERVER["port"]
443

444
    uvicorn.run(app, host=host, port=port)
445

446

447
def run_webui(started_event: mp.Event = None, run_mode: str = None):
448
    from server.utils import set_httpx_config
449
    set_httpx_config()
450

451
    host = WEBUI_SERVER["host"]
452
    port = WEBUI_SERVER["port"]
453

454
    cmd = ["streamlit", "run", "webui.py",
455
           "--server.address", host,
456
           "--server.port", str(port),
457
           "--theme.base", "light",
458
           "--theme.primaryColor", "#165dff",
459
           "--theme.secondaryBackgroundColor", "#f5f5f5",
460
           "--theme.textColor", "#000000",
461
           ]
462
    if run_mode == "lite":
463
        cmd += [
464
            "--",
465
            "lite",
466
        ]
467
    p = subprocess.Popen(cmd)
468
    started_event.set()
469
    p.wait()
470

471

472
def parse_args() -> Tuple[argparse.Namespace, argparse.ArgumentParser]:
473
    parser = argparse.ArgumentParser()
474
    parser.add_argument(
475
        "-a",
476
        "--all-webui",
477
        action="store_true",
478
        help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
479
        dest="all_webui",
480
    )
481
    parser.add_argument(
482
        "--all-api",
483
        action="store_true",
484
        help="run fastchat's controller/openai_api/model_worker servers, run api.py",
485
        dest="all_api",
486
    )
487
    parser.add_argument(
488
        "--llm-api",
489
        action="store_true",
490
        help="run fastchat's controller/openai_api/model_worker servers",
491
        dest="llm_api",
492
    )
493
    parser.add_argument(
494
        "-o",
495
        "--openai-api",
496
        action="store_true",
497
        help="run fastchat's controller/openai_api servers",
498
        dest="openai_api",
499
    )
500
    parser.add_argument(
501
        "-m",
502
        "--model-worker",
503
        action="store_true",
504
        help="run fastchat's model_worker server with specified model name. "
505
             "specify --model-name if not using default LLM_MODELS",
506
        dest="model_worker",
507
    )
508
    parser.add_argument(
509
        "-n",
510
        "--model-name",
511
        type=str,
512
        nargs="+",
513
        default=LLM_MODELS,
514
        help="specify model name for model worker. "
515
             "add addition names with space seperated to start multiple model workers.",
516
        dest="model_name",
517
    )
518
    parser.add_argument(
519
        "-c",
520
        "--controller",
521
        type=str,
522
        help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER",
523
        dest="controller_address",
524
    )
525
    parser.add_argument(
526
        "--api",
527
        action="store_true",
528
        help="run api.py server",
529
        dest="api",
530
    )
531
    parser.add_argument(
532
        "-p",
533
        "--api-worker",
534
        action="store_true",
535
        help="run online model api such as zhipuai",
536
        dest="api_worker",
537
    )
538
    parser.add_argument(
539
        "-w",
540
        "--webui",
541
        action="store_true",
542
        help="run webui.py server",
543
        dest="webui",
544
    )
545
    parser.add_argument(
546
        "-q",
547
        "--quiet",
548
        action="store_true",
549
        help="减少fastchat服务log信息",
550
        dest="quiet",
551
    )
552
    parser.add_argument(
553
        "-i",
554
        "--lite",
555
        action="store_true",
556
        help="以Lite模式运行:仅支持在线API的LLM对话、搜索引擎对话",
557
        dest="lite",
558
    )
559
    args = parser.parse_args()
560
    return args, parser
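# Typical invocations (a sketch based on the flags defined above; the model name is an example):
#   python startup.py -a                    # fastchat servers + api.py + webui.py
#   python startup.py --all-api             # everything except the web UI
#   python startup.py -m -n chatglm3-6b     # only a model worker for the given model
#   python startup.py -a --lite             # Lite mode: online-API chat only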
561

562

563
def dump_server_info(after_start=False, args=None):
564
    import platform
565
    import langchain
566
    import fastchat
567
    from server.utils import api_address, webui_address
568

569
    print("\n")
570
    print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
571
    print(f"操作系统:{platform.platform()}.")
572
    print(f"python版本:{sys.version}")
573
    print(f"项目版本:{VERSION}")
574
    print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
575
    print("\n")
576

577
    models = LLM_MODELS
578
    if args and args.model_name:
579
        models = args.model_name
580

581
    print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
582
    print(f"当前启动的LLM模型:{models} @ {llm_device()}")
583

584
    for model in models:
585
        pprint(get_model_worker_config(model))
586
    print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
587

588
    if after_start:
589
        print("\n")
590
        print(f"服务端运行信息:")
591
        if args.openai_api:
592
            print(f"    OpenAI API Server: {fschat_openai_api_address()}")
593
        if args.api:
594
            print(f"    Chatchat  API  Server: {api_address()}")
595
        if args.webui:
596
            print(f"    Chatchat WEBUI Server: {webui_address()}")
597
    print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
598
    print("\n")
599

600

601
async def start_main_server():
602
    import time
603
    import signal
604

605
    def handler(signalname):
606
        """
607
        Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
608
        Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
609
        """
610

611
        def f(signal_received, frame):
612
            raise KeyboardInterrupt(f"{signalname} received")
613

614
        return f
615

616
    # This will be inherited by the child process if it is forked (not spawned)
617
    signal.signal(signal.SIGINT, handler("SIGINT"))
618
    signal.signal(signal.SIGTERM, handler("SIGTERM"))
619

620
    mp.set_start_method("spawn")
621
    manager = mp.Manager()
622
    run_mode = None
623

624
    queue = manager.Queue()
625
    args, parser = parse_args()
626

627
    if args.all_webui:
628
        args.openai_api = True
629
        args.model_worker = True
630
        args.api = True
631
        args.api_worker = True
632
        args.webui = True
633

634
    elif args.all_api:
635
        args.openai_api = True
636
        args.model_worker = True
637
        args.api = True
638
        args.api_worker = True
639
        args.webui = False
640

641
    elif args.llm_api:
642
        args.openai_api = True
643
        args.model_worker = True
644
        args.api_worker = True
645
        args.api = False
646
        args.webui = False
647

648
    if args.lite:
649
        args.model_worker = False
650
        run_mode = "lite"
651

652
    dump_server_info(args=args)
653

654
    if len(sys.argv) > 1:
655
        logger.info(f"正在启动服务:")
656
        logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
657

658
    processes = {"online_api": {}, "model_worker": {}}
659

660
    def process_count():
661
        return len(processes) + len(processes["online_api"]) + len(processes["model_worker"]) - 2
662

663
    if args.quiet or not log_verbose:
664
        log_level = "ERROR"
665
    else:
666
        log_level = "INFO"
667

668
    controller_started = manager.Event()
669
    if args.openai_api:
670
        process = Process(
671
            target=run_controller,
672
            name=f"controller",
673
            kwargs=dict(log_level=log_level, started_event=controller_started),
674
            daemon=True,
675
        )
676
        processes["controller"] = process
677

678
        process = Process(
679
            target=run_openai_api,
680
            name=f"openai_api",
681
            daemon=True,
682
        )
683
        processes["openai_api"] = process
684

685
    model_worker_started = []
686
    if args.model_worker:
687
        for model_name in args.model_name:
688
            config = get_model_worker_config(model_name)
689
            if not config.get("online_api"):
690
                e = manager.Event()
691
                model_worker_started.append(e)
692
                process = Process(
693
                    target=run_model_worker,
694
                    name=f"model_worker - {model_name}",
695
                    kwargs=dict(model_name=model_name,
696
                                controller_address=args.controller_address,
697
                                log_level=log_level,
698
                                q=queue,
699
                                started_event=e),
700
                    daemon=True,
701
                )
702
                processes["model_worker"][model_name] = process
703

704
    if args.api_worker:
705
        for model_name in args.model_name:
706
            config = get_model_worker_config(model_name)
707
            if (config.get("online_api")
708
                    and config.get("worker_class")
709
                    and model_name in FSCHAT_MODEL_WORKERS):
710
                e = manager.Event()
711
                model_worker_started.append(e)
712
                process = Process(
713
                    target=run_model_worker,
714
                    name=f"api_worker - {model_name}",
715
                    kwargs=dict(model_name=model_name,
716
                                controller_address=args.controller_address,
717
                                log_level=log_level,
718
                                q=queue,
719
                                started_event=e),
720
                    daemon=True,
721
                )
722
                processes["online_api"][model_name] = process
723

724
    api_started = manager.Event()
725
    if args.api:
726
        process = Process(
727
            target=run_api_server,
728
            name=f"API Server",
729
            kwargs=dict(started_event=api_started, run_mode=run_mode),
730
            daemon=True,
731
        )
732
        processes["api"] = process
733

734
    webui_started = manager.Event()
735
    if args.webui:
736
        process = Process(
737
            target=run_webui,
738
            name=f"WEBUI Server",
739
            kwargs=dict(started_event=webui_started, run_mode=run_mode),
740
            daemon=True,
741
        )
742
        processes["webui"] = process
743

744
    if process_count() == 0:
745
        parser.print_help()
746
    else:
747
        try:
748
            # make sure the tasks exit cleanly after receiving SIGINT
749
            if p := processes.get("controller"):
750
                p.start()
751
                p.name = f"{p.name} ({p.pid})"
752
                controller_started.wait()  # wait for the controller to finish starting
753

754
            if p := processes.get("openai_api"):
755
                p.start()
756
                p.name = f"{p.name} ({p.pid})"
757

758
            for n, p in processes.get("model_worker", {}).items():
759
                p.start()
760
                p.name = f"{p.name} ({p.pid})"
761

762
            for n, p in processes.get("online_api", []).items():
763
                p.start()
764
                p.name = f"{p.name} ({p.pid})"
765

766
            for e in model_worker_started:
767
                e.wait()
768

769
            if p := processes.get("api"):
770
                p.start()
771
                p.name = f"{p.name} ({p.pid})"
772
                api_started.wait()
773

774
            if p := processes.get("webui"):
775
                p.start()
776
                p.name = f"{p.name} ({p.pid})"
777
                webui_started.wait()
778

779
            dump_server_info(after_start=True, args=args)
780

781
            while True:
782
                cmd = queue.get()
783
                e = manager.Event()
784
                if isinstance(cmd, list):
785
                    model_name, cmd, new_model_name = cmd
786
                    if cmd == "start":  # 运行新模型
787
                        logger.info(f"准备启动新模型进程:{new_model_name}")
788
                        process = Process(
789
                            target=run_model_worker,
790
                            name=f"model_worker - {new_model_name}",
791
                            kwargs=dict(model_name=new_model_name,
792
                                        controller_address=args.controller_address,
793
                                        log_level=log_level,
794
                                        q=queue,
795
                                        started_event=e),
796
                            daemon=True,
797
                        )
798
                        process.start()
799
                        process.name = f"{process.name} ({process.pid})"
800
                        processes["model_worker"][new_model_name] = process
801
                        e.wait()
802
                        logger.info(f"成功启动新模型进程:{new_model_name}")
803
                    elif cmd == "stop":
804
                        if process := processes["model_worker"].get(model_name):
805
                            time.sleep(1)
806
                            process.terminate()
807
                            process.join()
808
                            logger.info(f"停止模型进程:{model_name}")
809
                        else:
810
                            logger.error(f"未找到模型进程:{model_name}")
811
                    elif cmd == "replace":
812
                        if process := processes["model_worker"].pop(model_name, None):
813
                            logger.info(f"停止模型进程:{model_name}")
814
                            start_time = datetime.now()
815
                            time.sleep(1)
816
                            process.terminate()
817
                            process.join()
818
                            process = Process(
819
                                target=run_model_worker,
820
                                name=f"model_worker - {new_model_name}",
821
                                kwargs=dict(model_name=new_model_name,
822
                                            controller_address=args.controller_address,
823
                                            log_level=log_level,
824
                                            q=queue,
825
                                            started_event=e),
826
                                daemon=True,
827
                            )
828
                            process.start()
829
                            process.name = f"{process.name} ({process.pid})"
830
                            processes["model_worker"][new_model_name] = process
831
                            e.wait()
832
                            timing = datetime.now() - start_time
833
                            logger.info(f"成功启动新模型进程:{new_model_name}。用时:{timing}。")
834
                        else:
835
                            logger.error(f"未找到模型进程:{model_name}")
836

837
            # for process in processes.get("model_worker", {}).values():
838
            #     process.join()
839
            # for process in processes.get("online_api", {}).values():
840
            #     process.join()
841

842
            # for name, process in processes.items():
843
            #     if name not in ["model_worker", "online_api"]:
844
            #         if isinstance(p, dict):
845
            #             for work_process in p.values():
846
            #                 work_process.join()
847
            #         else:
848
            #             process.join()
849
        except Exception as e:
850
            logger.error(e)
851
            logger.warning("Caught KeyboardInterrupt! Setting stop event...")
852
        finally:
853

854
            for p in processes.values():
855
                logger.warning("Sending SIGKILL to %s", p)
856
                # Queues and other inter-process communication primitives can break when
857
                # process is killed, but we don't care here
858

859
                if isinstance(p, dict):
860
                    for process in p.values():
861
                        process.kill()
862
                else:
863
                    p.kill()
864

865
            for p in processes.values():
866
                logger.info("Process status: %s", p)
867

868

869
if __name__ == "__main__":
870
    create_tables()
871
    if sys.version_info < (3, 10):
872
        loop = asyncio.get_event_loop()
873
    else:
874
        try:
875
            loop = asyncio.get_running_loop()
876
        except RuntimeError:
877
            loop = asyncio.new_event_loop()
878

879
        asyncio.set_event_loop(loop)
880

881
    loop.run_until_complete(start_main_server())
882

883
# Example of calling the API after the services have started:
884
# import openai
885
# openai.api_key = "EMPTY" # Not support yet
886
# openai.api_base = "http://localhost:8888/v1"
887

888
# model = "chatglm3-6b"
889

890
# # create a chat completion
891
# completion = openai.ChatCompletion.create(
892
#   model=model,
893
#   messages=[{"role": "user", "content": "Hello! What is your name?"}]
894
# )
895
# # print the completion
896
# print(completion.choices[0].message.content)
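# The snippet above uses the legacy openai<1.0 interface. A rough equivalent with the
# openai>=1.0 client (an assumption, not part of the original example) would be:
# from openai import OpenAI
# client = OpenAI(base_url="http://localhost:8888/v1", api_key="EMPTY")
# completion = client.chat.completions.create(
#     model="chatglm3-6b",
#     messages=[{"role": "user", "content": "Hello! What is your name?"}],
# )
# print(completion.choices[0].message.content)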
897
