1
# Licensed to the Apache Software Foundation (ASF) under one
2
# or more contributor license agreements. See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership. The ASF licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License. You may obtain a copy of the License at
9
# http://www.apache.org/licenses/LICENSE-2.0
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied. See the License for the
15
# specific language governing permissions and limitations
18
This module contains utilities to auto-generate an
19
Entity-Relationship Diagram (ERD) from SQLAlchemy
20
and write it to a PlantUML file.
25
from collections import defaultdict
26
from collections.abc import Iterable
27
from typing import Any, Optional
32
from superset import db
34
GROUPINGS: dict[str, Iterable[str]] = {
42
"embedded_dashboards",
48
"System": ["ssh_tunnels", "keyvalue", "cache_keys", "key_value", "logs"],
49
"Alerts & Reports": ["report_recipient", "report_execution_log", "report_schedule"],
50
"Inherited from Flask App Builder (FAB)": [
58
"SQL Lab": ["query", "saved_query", "tab_state", "table_schema"],
64
"row_level_security_filters",
68
"database_user_oauth2_tokens",
71
# Table name to group name mapping (reversing GROUPINGS for easy lookup).
# If a table is listed under more than one group, the group iterated last wins.
TABLE_TO_GROUP_MAP: dict[str, str] = {}
for group, tables in GROUPINGS.items():
    # Each group lists several tables; register every one of them.
    for table in tables:
        TABLE_TO_GROUP_MAP[table] = group
78
def sort_data_structure(data):  # type: ignore
    """
    Return a copy of ``data`` in which every dict has its keys sorted.

    The value is serialized to JSON with ``sort_keys=True`` and parsed back,
    so nested dicts are sorted too. Because of that round trip, ``data`` must
    be JSON-serializable and the result only contains JSON types (tuples come
    back as lists, non-string dict keys come back as strings).

    Parameters:
    -----------
    data: JSON-serializable value, typically a dict

    Returns:
    --------
    The round-tripped structure with dict keys in sorted order
    """
    sorted_json = json.dumps(data, sort_keys=True)
    return json.loads(sorted_json)
84
def introspect_sqla_model(mapper: Any, seen: set[str]) -> dict[str, Any]:
86
Introspects a SQLAlchemy model and returns a data structure that
87
can be passed to a jinja2 template, for instance
91
mapper: SQLAlchemy model mapper
92
seen: set of relationship identifiers (sorted table-name pairs) already emitted, used to avoid duplicates
96
Dict[str, Any]: data structure for jinja2 template
98
table_name = mapper.persist_selectable.name
99
model_info: dict[str, Any] = {
100
"class_name": mapper.class_.__name__,
101
"table_name": table_name,
105
# Collect fields (columns) and their types
106
for column in mapper.columns:
107
field_info: dict[str, str] = {
108
"field_name": column.key,
109
"type": str(column.type),
111
model_info["fields"].append(field_info)
113
# Collect relationships and identify types
114
for attr, relationship in mapper.relationships.items():
115
related_table = relationship.mapper.persist_selectable.name
116
# Create a unique identifier for the relationship to avoid duplicates
117
relationship_id = "-".join(sorted([table_name, related_table]))
119
if relationship_id not in seen:
120
seen.add(relationship_id)
122
if relationship.direction.name == "MANYTOONE":
125
relationship_info: dict[str, str] = {
126
"relationship_name": attr,
127
"related_model": relationship.mapper.class_.__name__,
128
"type": relationship.direction.name,
129
"related_table": related_table,
131
# Identify many-to-many by checking for secondary table
132
if relationship.secondary is not None:
134
relationship_info["type"] = "many-to-many"
135
relationship_info["secondary_table"] = relationship.secondary.name
137
relationship_info["squiggle"] = squiggle
138
model_info["relationships"].append(relationship_info)
139
return sort_data_structure(model_info) # type: ignore
142
def introspect_models() -> dict[str, list[dict[str, Any]]]:
    """
    Introspects SQLAlchemy models and returns a data structure that
    can be passed to a jinja2 template for rendering an ERD.

    Every mapper registered on ``db.Model`` is introspected and filed under
    the group its table is assigned to in ``TABLE_TO_GROUP_MAP``; tables
    without a group are filed under "Uncategorized Models".

    Returns:
    --------
    Dict[str, List[Dict[str, Any]]]: data structure for jinja2 template,
        keyed by group name
    """
    data: dict[str, list[dict[str, Any]]] = defaultdict(list)
    # Shared across all models so that each pair of related tables is only
    # reported once, by whichever model is introspected first.
    seen_models: set[str] = set()
    for model in db.Model.registry.mappers:
        group_name = (
            TABLE_TO_GROUP_MAP.get(model.mapper.persist_selectable.name)
            or "Uncategorized Models"
        )
        model_data = introspect_sqla_model(model, seen_models)
        data[group_name].append(model_data)
    return data
163
def generate_erd(file_path: str) -> None:
    """
    Generates a PlantUML ERD of the models/database

    The model data from ``introspect_models`` is rendered through the
    ``erd.template.puml`` jinja2 template, which is loaded from the directory
    containing this module, and the result is written to ``file_path``.

    Parameters:
    -----------
    file_path: str
        File path to write the ERD to
    """
    data = introspect_models()
    templates_path = os.path.dirname(__file__)
    env = jinja2.Environment(loader=jinja2.FileSystemLoader(templates_path))

    template = env.get_template("erd.template.puml")
    rendered = template.render(data=data)
    # Explicit encoding keeps the output identical across platforms instead of
    # depending on the locale's default encoding.
    with open(file_path, "w", encoding="utf-8") as f:
        click.secho(f"Writing to {file_path}...", fg="green")
        f.write(rendered)
188
type=click.Path(dir_okay=False, writable=True),
189
help="File to write the ERD to",
191
def erd(output: Optional[str] = None) -> None:
193
Generates a PlantUML ERD of the models/database
197
output: str, optional
198
File to write the ERD to, defaults to erd.puml in this module's directory if not provided
200
path = os.path.dirname(__file__)
201
output = output or os.path.join(path, "erd.puml")
203
from superset.app import create_app
206
with app.app_context():
210
if __name__ == "__main__":