1
from __future__ import annotations
7
from collections.abc import Callable
11
from fastapi import FastAPI
12
from starlette.middleware.base import RequestResponseEndpoint
13
from starlette.requests import Request
14
from starlette.responses import Response
15
from starlette.routing import Match, Mount
16
from starlette.types import Scope
18
from core.db_config import config
# Name of the `request.state` attribute that holds the active per-request timer;
# the timing middleware sets it and `record_timing` reads it back.
TIMER_ATTRIBUTE: str = "__fastapi_restful_timer__"
def add_timing_middleware(
    app: FastAPI, record: Callable[[str], None] | None = None, prefix: str = "", exclude: str | None = None
) -> None:
    """
    Install an HTTP middleware on `app` that times every request and reports the result via `record`.

    `record` receives one formatted message per request; typically it is something like
    `logger.info` of a `logging.Logger` instance.

    `prefix` is used when generating the route (metric) names.

    When `exclude` is given, timings are not logged for any route whose generated metric name
    contains `exclude` as an exact substring. This gives an easy way to disable logging for
    selected routes. The substring test may be replaced by a regex match in the future.

    While a request is handled, its timer is stored on `request.state`, which is what lets
    `record_timing` emit intermediate timings from inside a handler.
    """
    namer = _MetricNamer(prefix=prefix, app=app)

    @app.middleware("http")
    async def timing_middleware(request: Request, call_next: RequestResponseEndpoint) -> Response:
        timer = _TimingStats(namer(request.scope), record=record, exclude=exclude)
        with timer:
            # Expose the live timer so handlers can call `record_timing(request, ...)`.
            setattr(request.state, TIMER_ATTRIBUTE, timer)
            # The timer's __exit__ runs (and emits) before the response is handed back.
            return await call_next(request)
def record_timing(request: Request, note: str | None = None) -> None:
    """
    Emit the time elapsed so far while handling `request`, optionally tagged with `note`.

    Call this at any point during request handling to display intermediate timings.
    This helps find which part of a request is causing a performance bottleneck.

    For this to succeed, the request must come from a FastAPI app on which
    `add_timing_middleware` was installed; that middleware attaches the timer to `request.state`.

    :raises ValueError: if no timer is attached to the request, or if the attached object
        is not a `_TimingStats` instance.
    """
    timer = getattr(request.state, TIMER_ATTRIBUTE, None)
    if timer is None:
        raise ValueError("No timer present on request")
    if not isinstance(timer, _TimingStats):
        raise ValueError("Timer should be of an instance of TimingStats")
    timer.emit(note)
70
This class tracks and records endpoint timing data.
72
Should be used as a context manager; on exit, timing stats will be emitted.
75
The name to include with the recorded timing data
77
The callable to call on generated messages. Defaults to `print`, but typically
78
something like `logger.info` for a `logging.Logger` instance would be preferable.
80
An optional string; if it is not None and occurs inside `name`, no stats will be emitted
84
self, name: str | None = None, record: Callable[[str], None] | None = None, exclude: str | None = None
87
self.record = record or print
89
self.process: psutil.Process = psutil.Process(os.getpid())
90
self.start_time: float = 0
91
self.start_cpu_time: float = 0
92
self.end_cpu_time: float = 0
93
self.end_time: float = 0
94
self.silent: bool = False
96
if self.name is not None and exclude is not None and (exclude in self.name):
99
def start(self) -> None:
100
self.start_time = time.time()
101
self.start_cpu_time = self._get_cpu_time()
103
def take_split(self) -> None:
104
self.end_time = time.time()
105
self.end_cpu_time = self._get_cpu_time()
108
def time(self) -> float:
109
return self.end_time - self.start_time
112
def cpu_time(self) -> float:
113
return self.end_cpu_time - self.start_cpu_time
115
def __enter__(self) -> _TimingStats:
119
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
122
def emit(self, note: str | None = None) -> None:
124
Emit timing information, optionally including a specified note
128
cpu_ms = 1000 * self.cpu_time
129
wall_ms = 1000 * self.time
130
message = f"TIMING: Wall: {wall_ms:6.1f}ms | CPU: {cpu_ms:6.1f}ms | {self.name}"
132
message += f" ({note})"
135
def _get_cpu_time(self) -> float:
137
Generates the cpu time to report. Adds the user and system time, following the implementation from timing-asgi
139
resources = self.process.cpu_times()
141
return resources[0] + resources[1]
146
This class generates the route "name" used when logging timing records.
148
If the route has `endpoint` and `name` attributes, the endpoint's module and route's name will be used
149
(along with an optional prefix that can be used, e.g., to distinguish between multiple mounted ASGI apps).
151
By default, in FastAPI the route name is the `__name__` of the route's function (or type if it is a callable class
154
For example, with prefix == "custom", a function defined in the module `app.crud` with name `read_item`
155
would get name `custom.app.crud.read_item`. If the empty string were used as the prefix, the result would be
156
just "app.crud.read_item".
158
For starlette.routing.Mount instances, the name of the type of `route.app` is used in a slightly different format.
160
For other routes missing either an endpoint or name, the raw route path is included in the generated name.
163
def __init__(self, prefix: str, app: FastAPI):
169
def __call__(self, scope: Scope) -> str:
171
Generates the actual name to use when logging timing metrics for a specified ASGI Scope
174
for r in self.app.router.routes:
175
if r.matches(scope)[0] == Match.FULL:
178
if hasattr(route, "endpoint") and hasattr(route, "name"):
179
name = f"{self.prefix}{route.endpoint.__module__}.{route.name}"
180
elif isinstance(route, Mount):
181
name = f"{type(route.app).__name__}<{route.name!r}>"
183
name = str(f"<Path: {scope['path']}>")
def timed(fn: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator that logs when a function starts and finishes, and how long it took.

    Timing is only enabled when `config.ENV` is 'local' or 'dev'. In any other environment,
    `fn` is returned undecorated, so it carries no overhead. Coroutine functions get an async
    wrapper that awaits `fn`; plain functions get a synchronous wrapper. Both wrappers copy
    `fn`'s metadata (`__name__`, `__doc__`, `__wrapped__`) so introspection still sees the
    original signature. Messages go to the root logger via `logging.info`. If `fn` raises,
    the "Finished" message is skipped and the exception propagates unchanged.

    :param fn: Function to decorate
    :return: Decorated function, or `fn` itself outside the local/dev environments
    """
    import inspect
    from functools import wraps

    if config.ENV not in ('local', 'dev'):
        return fn

    if inspect.iscoroutinefunction(fn):

        @wraps(fn)
        async def wrapped_fn_async(*args: Any, **kwargs: Any) -> Any:
            start = time.time()
            logging.info("Running: %s", fn.__name__)
            ret = await fn(*args, **kwargs)
            # get_duration_str measures against time.time(), so `start` must use that clock too.
            duration_str = get_duration_str(start)
            logging.info("Finished: %s in %s", fn.__name__, duration_str)
            return ret

        return wrapped_fn_async

    @wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Any:
        start = time.time()
        logging.info("Running: %s", fn.__name__)
        ret = fn(*args, **kwargs)
        # get_duration_str measures against time.time(), so `start` must use that clock too.
        duration_str = get_duration_str(start)
        logging.info("Finished: %s in %s", fn.__name__, duration_str)
        return ret

    return wrapped_fn
216
def get_duration_str(start: float) -> str:
217
"""Get human readable duration string from start time"""
218
duration = time.time() - start
220
duration_str = f'{duration:,.3f}s'
221
elif duration > 1e-3:
222
duration_str = f'{round(duration * 1e3)}ms'
223
elif duration > 1e-6:
224
duration_str = f'{round(duration * 1e6)}us'
226
duration_str = f'{duration * 1e9}ns'