import importlib.util
import logging
import re
import time
from collections import defaultdict
from graphlib import TopologicalSorter
from inspect import getsource
from pathlib import Path
from types import ModuleType
from typing import Any

import click
from flask import current_app
from flask_appbuilder import Model
from flask_migrate import downgrade, upgrade
from progress.bar import ChargingBar
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.automap import automap_base

from superset import db
from superset.utils.mock_data import add_sample_rows

logger = logging.getLogger(__name__)
42
def import_migration_script(filepath: Path) -> ModuleType:
    """
    Import a migration script as if it were a module.

    :param filepath: path to the Alembic migration script
    :returns: the executed module, with ``revision``/``down_revision`` and
        ``upgrade``/``downgrade`` available as attributes
    :raises Exception: if no import spec can be built for the file
    """
    spec = importlib.util.spec_from_file_location(filepath.stem, filepath)
    if spec and spec.loader:
        module = importlib.util.module_from_spec(spec)
        # run the script so its module-level names (revision, upgrade, ...)
        # are populated
        spec.loader.exec_module(module)
        return module
    raise Exception(f"No module spec found in location: `{str(filepath)}`")
54
def extract_modified_tables(module: ModuleType) -> set[str]:
    """
    Extract the tables being modified by a migration script.

    This function uses a simple approach of looking at the source code of
    the migration script looking for patterns. It could be improved by
    actually traversing the AST.

    :param module: an imported migration script
    :returns: names of tables touched by ``alter_table``/``add_column``/
        ``drop_column`` calls in ``upgrade`` or ``downgrade``
    """
    tables: set[str] = set()
    for function in {"upgrade", "downgrade"}:
        source = getsource(getattr(module, function))
        tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL))
        tables.update(re.findall(r'add_column\(\s*"(\w+?)"\s*,', source, re.DOTALL))
        tables.update(re.findall(r'drop_column\(\s*"(\w+?)"\s*,', source, re.DOTALL))

    return tables
73
def find_models(module: ModuleType) -> list[type[Model]]:
    """
    Find all models in a migration script.

    Returns the models in topological order (dependencies first), so that
    sample entities can be created respecting foreign-key relationships.

    :param module: an imported migration script
    :returns: automapped model classes for every table touched by the
        migration, plus tables reachable via foreign keys
    """
    models: list[type[Model]] = []
    tables = extract_modified_tables(module)

    # add tables of models defined explicitly in the migration script
    queue = list(module.__dict__.values())
    while queue:
        obj = queue.pop()
        if hasattr(obj, "__tablename__"):
            tables.add(obj.__tablename__)
        elif isinstance(obj, list):
            queue.extend(obj)
        elif isinstance(obj, dict):
            queue.extend(obj.values())

    # build models by automapping the existing tables, instead of using the
    # current ORM models; this is needed for migrations that modify schemas
    # (eg, add a column), where the current model is out of sync with the
    # existing table after a downgrade
    sqlalchemy_uri = current_app.config["SQLALCHEMY_DATABASE_URI"]
    engine = create_engine(sqlalchemy_uri)
    Base = automap_base()
    Base.prepare(engine, reflect=True)

    # expand to the transitive closure of tables referenced via foreign keys
    seen: set[str] = set()
    while tables:
        table = tables.pop()
        seen.add(table)
        try:
            model = getattr(Base.classes, table)
        except AttributeError:
            # table has no automapped class (eg, no primary key); skip it
            continue
        model.__tablename__ = table
        models.append(model)

        inspector = inspect(model)
        for column in inspector.columns.values():
            for foreign_key in column.foreign_keys:
                table = foreign_key.column.table.name
                if table not in seen:
                    tables.add(table)

    # sort topologically so we can create entities in order and maintain
    # relationships (eg, create a database before creating a chart)
    sorter: TopologicalSorter[Any] = TopologicalSorter()
    for model in models:
        inspector = inspect(model)
        dependent_tables: list[str] = []
        for column in inspector.columns.values():
            for foreign_key in column.foreign_keys:
                # self-referential FKs would create a cycle; skip them
                if foreign_key.column.table.name != model.__tablename__:
                    dependent_tables.append(foreign_key.column.table.name)
        sorter.add(model.__tablename__, *dependent_tables)
    order = list(sorter.static_order())
    models.sort(key=lambda model: order.index(model.__tablename__))

    return models
137
@click.command()
@click.argument("filepath")
@click.option("--limit", default=1000, help="Maximum number of entities.")
@click.option("--force", is_flag=True, help="Do not prompt for confirmation.")
@click.option("--no-auto-cleanup", is_flag=True, help="Do not remove created models.")
def main(
    filepath: str, limit: int = 1000, force: bool = False, no_auto_cleanup: bool = False
) -> None:
    """
    Benchmark a migration script against growing numbers of entities.

    Downgrades the DB to the script's ``down_revision``, then repeatedly
    seeds sample rows (10, 100, ... up to ``limit``) and times the upgrade.

    :param filepath: path to the Alembic migration script to benchmark
    :param limit: maximum number of entities per model to seed
    :param force: skip all confirmation prompts
    :param no_auto_cleanup: leave the created sample rows in the DB
    """
    auto_cleanup = not no_auto_cleanup

    print(f"Importing migration script: {filepath}")
    module = import_migration_script(Path(filepath))

    revision: str = getattr(module, "revision", "")
    down_revision: str = getattr(module, "down_revision", "")
    if not revision or not down_revision:
        raise Exception(
            "Not a valid migration script, couldn't find down_revision/revision"
        )

    print(f"Migration goes from {down_revision} to {revision}")
    current_revision = db.engine.execute(
        "SELECT version_num FROM alembic_version"
    ).scalar()
    print(f"Current version of the DB is {current_revision}")

    if current_revision != down_revision:
        # downgrades can drop columns/tables, so ask before proceeding
        if not force:
            click.confirm(
                "\nRunning benchmark will downgrade the Superset DB to "
                f"{down_revision} and upgrade to {revision} again. There may "
                "be data loss in downgrades. Continue?",
                abort=True,
            )
        downgrade(revision=down_revision)

    print("\nIdentifying models used in the migration:")
    models = find_models(module)
    model_rows: dict[type[Model], int] = {}
    for model in models:
        rows = db.session.query(model).count()
        print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
        model_rows[model] = rows

    print("Benchmarking migration")
    results: dict[str, float] = {}
    start = time.time()
    upgrade(revision=revision)
    duration = time.time() - start
    results["Current"] = duration
    print(f"Migration on current DB took: {duration:.2f} seconds")

    # grow the number of entities by an order of magnitude each round
    min_entities = 10
    new_models: dict[type[Model], list[Model]] = defaultdict(list)
    while min_entities <= limit:
        downgrade(revision=down_revision)
        print(f"Running with at least {min_entities} entities of each model")
        for model in models:
            missing = min_entities - model_rows[model]
            if missing > 0:
                entities: list[Model] = []
                print(f"- Adding {missing} entities to the {model.__name__} model")
                bar = ChargingBar("Processing", max=missing)
                try:
                    for entity in add_sample_rows(model, missing):
                        entities.append(entity)
                        bar.next()
                except Exception:
                    # leave the session usable; re-raise so the failure is visible
                    db.session.rollback()
                    raise
                bar.finish()
                model_rows[model] = min_entities
                db.session.add_all(entities)
                db.session.commit()

                if auto_cleanup:
                    # remember what we created so it can be deleted later
                    new_models[model].extend(entities)

        start = time.time()
        upgrade(revision=revision)
        duration = time.time() - start
        print(f"Migration for {min_entities}+ entities took: {duration:.2f} seconds")
        results[f"{min_entities}+"] = duration
        min_entities *= 10

    print("\nResults:\n")
    for label, duration in results.items():
        print(f"{label}: {duration:.2f} s")

    if auto_cleanup:
        print("Cleaning up DB")
        # delete in reverse creation order so FK constraints are respected
        for model, entities in list(new_models.items())[::-1]:
            db.session.query(model).filter(
                model.id.in_(entity.id for entity in entities)
            ).delete(synchronize_session=False)
        db.session.commit()

    if current_revision != revision and not force:
        click.confirm(f"\nRevert DB to {revision}?", abort=True)
        upgrade(revision=revision)
238
if __name__ == "__main__":
    from superset.app import create_app

    # the benchmark needs an app context for DB access and config lookups
    app = create_app()
    with app.app_context():
        main()