llama-index
152 строки · 3.8 Кб
import ast
import copy
import io
import tokenize
from types import CodeType, ModuleType
from typing import Any, Dict, Mapping, Sequence, Union
5
# Top-level modules that sandboxed code may import; consulted by
# _restricted_import.
# NOTE(review): some of these packages may expose stdlib modules such as `os`
# as plain attributes (e.g. `matplotlib.os`), which an underscore-only source
# check does not catch -- confirm when extending this list.
ALLOWED_IMPORTS = {
    # standard library
    "datetime",
    "math",
    "time",
    # numerical / data-analysis stack
    "numpy",
    "pandas",
    "scipy",
    # plotting
    "matplotlib",
    "plotly",
    "seaborn",
}
17
18
def _restricted_import(
    name: str,
    globals: Union[Mapping[str, object], None] = None,
    locals: Union[Mapping[str, object], None] = None,
    fromlist: Sequence[str] = (),
    level: int = 0,
) -> ModuleType:
    """Drop-in ``__import__`` replacement that only admits allow-listed packages.

    An absolute import is allowed when its top-level package is listed in
    ``ALLOWED_IMPORTS``. Submodule imports of an allowed package
    (``import matplotlib.pyplot``, ``import scipy.stats``) therefore work,
    while ``import os`` is refused. Relative imports (``level > 0``) are
    refused because the allow-list is keyed on absolute package names, and a
    relative name cannot be checked against it.

    Raises:
        ImportError: if the requested import is not allowed.
    """
    # `import a.b.c` arrives here as name="a.b.c"; only the root is listed.
    top_level_package = name.partition(".")[0]
    if level == 0 and top_level_package in ALLOWED_IMPORTS:
        return __import__(name, globals, locals, fromlist, level)
    raise ImportError(f"Import of module '{name}' is not allowed")
29
30
# Names exposed to sandboxed code as builtins. _get_restricted_globals()
# copies this mapping into the globals passed to eval()/exec().
# NOTE(review): getattr/setattr/hasattr take attribute names as runtime
# strings, so the source check must also vet their string arguments.
ALLOWED_BUILTINS = {
    "abs": abs,
    "all": all,
    "any": any,
    "ascii": ascii,
    "bin": bin,
    "bool": bool,
    "bytearray": bytearray,
    "bytes": bytes,
    "chr": chr,
    "complex": complex,
    # The container constructors list/set/tuple/frozenset are allowed; dict
    # was missing from that family.
    "dict": dict,
    "divmod": divmod,
    "enumerate": enumerate,
    "filter": filter,
    "float": float,
    "format": format,
    "frozenset": frozenset,
    "getattr": getattr,
    "hasattr": hasattr,
    "hash": hash,
    "hex": hex,
    "int": int,
    "isinstance": isinstance,
    "issubclass": issubclass,
    "iter": iter,
    "len": len,
    "list": list,
    "map": map,
    "max": max,
    "min": min,
    "next": next,
    "oct": oct,
    "ord": ord,
    "pow": pow,
    "print": print,
    "range": range,
    "repr": repr,
    "reversed": reversed,
    "round": round,
    "set": set,
    "setattr": setattr,
    "slice": slice,
    "sorted": sorted,
    "str": str,
    "sum": sum,
    "tuple": tuple,
    "type": type,
    "zip": zip,
    # Common exception types, so sandboxed code can raise and catch
    # ordinary errors (`try: ... except ValueError: ...`).
    "ArithmeticError": ArithmeticError,
    "AttributeError": AttributeError,
    "Exception": Exception,
    "IndexError": IndexError,
    "KeyError": KeyError,
    "LookupError": LookupError,
    "StopIteration": StopIteration,
    "TypeError": TypeError,
    "ValueError": ValueError,
    "ZeroDivisionError": ZeroDivisionError,
    # Constants
    "True": True,
    "False": False,
    "None": None,
    # Import hook: only allow-listed packages can be imported.
    "__import__": _restricted_import,
}
85
86
def _get_restricted_globals(__globals: Union[dict, None]) -> Any:
    """Build the globals mapping handed to ``eval``/``exec`` by the safe_* helpers.

    The result is a fresh dict. It holds a copy of ``ALLOWED_BUILTINS``,
    overlaid with the caller-supplied *__globals* (caller entries win),
    plus a ``__builtins__`` entry that is always a copy of
    ``ALLOWED_BUILTINS``.

    Pinning ``__builtins__`` is what makes the allow-list effective. When the
    globals passed to ``eval``/``exec`` lack that key, Python inserts the real
    ``builtins`` dictionary. Every builtin (``open``, ``exec``, ...) would then
    resolve, and ``import`` statements would call the real ``__import__``
    instead of ``_restricted_import``.

    The entry is written after *__globals* is merged, so a caller-supplied
    ``__builtins__`` (e.g. from passing a module's ``globals()``) cannot
    reopen that hole. Builtins missing from ``ALLOWED_BUILTINS`` are
    unavailable to the executed code and raise ``NameError``.
    """
    restricted_globals = copy.deepcopy(ALLOWED_BUILTINS)
    if __globals:
        restricted_globals.update(__globals)
    # A separate copy, so the executed code never shares the module-level dict.
    restricted_globals["__builtins__"] = dict(ALLOWED_BUILTINS)
    return restricted_globals
92
93
class DunderVisitor(ast.NodeVisitor):
    """Flag source code that reaches for private or interpreter-internal state.

    After ``visit(tree)``, ``has_access_to_private_entity`` is True if the
    tree contains any of:

    * a name or attribute whose identifier starts with ``_``;
    * an attribute listed in ``_BLOCKED_ATTRIBUTES``;
    * a direct ``getattr``/``setattr``/``hasattr``/``delattr`` call whose
      attribute name is not a plain string literal, or is a literal starting
      with ``_``. Otherwise ``getattr(obj, "__class__")`` would sidestep the
      identifier check;
    * any other reference to those builtins (e.g. ``g = getattr``), because
      the literal-argument check only covers direct calls.
    """

    # Builtins that receive the attribute name as a runtime string.
    _REFLECTIVE_BUILTINS = frozenset({"getattr", "setattr", "hasattr", "delattr"})

    # Attribute names that lead to interpreter internals without a leading
    # underscore. Frames of generators, coroutines, async generators and
    # tracebacks link to other frames and to their globals/builtins
    # (gi_frame -> f_back -> f_builtins, ...).
    # NOTE(review): the module names assume allow-listed packages may
    # re-export them as plain attributes (e.g. `matplotlib.os`) -- confirm
    # and extend as needed.
    _BLOCKED_ATTRIBUTES = frozenset(
        {
            "ag_frame",
            "cr_frame",
            "gi_frame",
            "tb_frame",
            "f_back",
            "f_builtins",
            "f_globals",
            "f_locals",
            "builtins",
            "importlib",
            "os",
            "subprocess",
            "sys",
        }
    )

    def __init__(self) -> None:
        # Latches to True on the first disallowed construct; never reset.
        self.has_access_to_private_entity = False

    def visit_Name(self, node: ast.Name) -> None:
        # visit_Call does not route the `func` name of a direct reflective
        # call through here, so any reflective name seen here is some other use.
        if node.id.startswith("_") or node.id in self._REFLECTIVE_BUILTINS:
            self.has_access_to_private_entity = True
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        if node.attr.startswith("_") or node.attr in self._BLOCKED_ATTRIBUTES:
            self.has_access_to_private_entity = True
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        func = node.func
        if isinstance(func, ast.Name) and func.id in self._REFLECTIVE_BUILTINS:
            if not self._has_public_literal_attr_name(node):
                self.has_access_to_private_entity = True
            # Skip `func` (so visit_Name does not flag this call) but still
            # inspect every argument.
            for child in (*node.args, *node.keywords):
                self.visit(child)
            return
        self.generic_visit(node)

    @staticmethod
    def _has_public_literal_attr_name(node: ast.Call) -> bool:
        """Return True if a reflective call names its attribute with a public literal.

        Star-args and keyword arguments are rejected conservatively: they make
        the attribute-name argument impossible to determine statically.
        """
        if node.keywords or len(node.args) < 2:
            return False
        if any(isinstance(arg, ast.Starred) for arg in node.args):
            return False
        attr_name = node.args[1]
        return (
            isinstance(attr_name, ast.Constant)
            and isinstance(attr_name.value, str)
            and not attr_name.value.startswith("_")
        )
107
108
def _contains_protected_access(code: str) -> bool:
    """Report whether *code* is flagged by a ``DunderVisitor`` scan.

    *code* is parsed with ``ast.parse`` (a ``SyntaxError`` propagates to the
    caller), then the resulting tree is walked once by a fresh visitor.
    """
    syntax_tree = ast.parse(code)
    scanner = DunderVisitor()
    scanner.visit(syntax_tree)
    return scanner.has_access_to_private_entity
114
115
def _decode_source_bytes(data: bytes) -> str:
    """Decode Python source bytes using the interpreter's encoding rules.

    ``tokenize.detect_encoding`` applies the UTF-8 BOM and PEP 263
    ``# coding:`` declaration rules, defaulting to UTF-8.
    """
    encoding, _ = tokenize.detect_encoding(io.BytesIO(data).readline)
    return data.decode(encoding)


def _verify_source_safety(__source: Union[str, bytes, CodeType]) -> None:
    """
    Verify that the source is safe to execute. For now, this means that it
    does not contain any references to private or dunder methods.

    Raises:
        RuntimeError: if *__source* is a code object (its source cannot be
            inspected) or if the protected-access check rejects it.
    """
    if isinstance(__source, CodeType):
        raise RuntimeError("Direct execution of CodeType is forbidden!")
    if isinstance(__source, bytes):
        # eval/exec decode bytes according to a BOM / coding declaration, not
        # as plain UTF-8. Decode the same way, so the inspected text is the
        # text that will actually run.
        __source = _decode_source_bytes(__source)
    if _contains_protected_access(__source):
        raise RuntimeError(
            "Execution of code containing references to private or dunder methods is forbidden!"
        )
129
130
def safe_eval(
    __source: Union[str, bytes, CodeType],
    __globals: Union[Dict[str, Any], None] = None,
    __locals: Union[Mapping[str, object], None] = None,
) -> Any:
    """
    Evaluate *__source* as an expression inside the restricted global context.

    The source is vetted by ``_verify_source_safety`` first. *__globals* is
    merged into the restricted globals built by ``_get_restricted_globals``,
    and *__locals* is passed through to ``eval`` unchanged.
    """
    _verify_source_safety(__source)
    sandbox_globals = _get_restricted_globals(__globals)
    return eval(__source, sandbox_globals, __locals)
141
142
def safe_exec(
    __source: Union[str, bytes, CodeType],
    __globals: Union[Dict[str, Any], None] = None,
    __locals: Union[Mapping[str, object], None] = None,
) -> None:
    """
    Execute *__source* as statements inside the restricted global context.

    The source is vetted by ``_verify_source_safety`` first. *__globals* is
    merged into a freshly built restricted globals dict, so the caller's own
    *__globals* dict is never written to.

    Top-level assignments made by the executed code land in *__locals* when
    one is given. Otherwise they land in that fresh globals dict, which is not
    returned, so pass *__locals* to read results back.
    """
    _verify_source_safety(__source)
    exec(__source, _get_restricted_globals(__globals), __locals)
153