ray-llm
44 строки · 1.0 Кб
1import contextvars
2import weakref
3from contextlib import contextmanager
4
5from fastapi.datastructures import State
6
7from rayllm.backend.observability.tracing import baggage
8
# Context variable holding the FastAPI request ``State`` for the current
# execution context. The State may include secrets.
# It is annotated to hold a ``weakref.ref`` to the State (which is what
# ``set_fastapi_state`` stores), so the variable itself does not keep the
# State object alive.
_fastapi_state_context: contextvars.ContextVar[weakref.ref[State]] = (
    contextvars.ContextVar("aviary_fastapi_state")
)
14
15
def set(**kwargs):
    """Forward all keyword arguments to the tracing ``baggage.baggage`` helper.

    The keyword arguments are passed as a single dict, and whatever
    ``baggage.baggage`` returns is handed back unchanged (presumably a context
    manager or context object; confirm against the tracing module).

    Note: this module-level name shadows the builtin ``set`` inside this module.
    """
    entries = kwargs
    result = baggage.baggage(entries)
    return result
18
19
def get(key: str):
    """Look up ``key`` through the tracing ``baggage.get`` helper.

    Args:
        key: Name of the baggage entry to read.

    Returns:
        Whatever ``baggage.get`` returns for ``key``.
    """
    value = baggage.get(key)
    return value
22
23
@contextmanager
def set_fastapi_state(request_state: State):
    """Bind ``request_state`` to the current context for the duration of a ``with`` block.

    Only a weak reference to the State is stored. The binding therefore does
    not keep the FastAPI request State alive once nothing else references it.
    On exit, whether normal or via an exception, the context variable is reset
    to the value it had before entry.

    Args:
        request_state: The FastAPI ``State`` object to expose through
            ``get_fastapi_state``.
    """
    reset_token = _fastapi_state_context.set(weakref.ref(request_state))
    try:
        yield
    finally:
        # Restore whatever was bound before entry, even if the body raised.
        _fastapi_state_context.reset(reset_token)
33
34
def get_fastapi_state():
    """Return the FastAPI ``State`` bound via ``set_fastapi_state``, if any.

    Returns:
        The bound State object, or ``None`` if either:
        - nothing is bound in the current context; or
        - the referenced State has already been garbage collected (a dead
          weak reference dereferences to ``None``).
    """
    state_ref = _fastapi_state_context.get(None)
    # Test for "unset" explicitly rather than relying on the truthiness of the
    # weakref object. Liveness of the referent is handled by dereferencing.
    if state_ref is None:
        return None
    return state_ref()
39
40
def maybe_get_string_field(field: str):
    """Return attribute ``field`` of the current FastAPI State if it is a string.

    Args:
        field: Name of the attribute to read from the bound State.

    Returns:
        The attribute value when all of these hold:
        - a State is bound and still alive;
        - the attribute exists;
        - its value is a ``str``.
        Otherwise ``None``.
    """
    state = get_fastapi_state()
    # With no bound State there is nothing to read. Returning early ensures the
    # result never comes from attributes of ``None`` itself.
    if state is None:
        return None
    val = getattr(state, field, None)
    return val if isinstance(val, str) else None
45