2
Baseball statistics application with txtai and Streamlit.
4
Install txtai and streamlit (>= 1.23) to run:
5
pip install txtai streamlit
18
from txtai.embeddings import Embeddings
23
Base stats class. Contains methods for loading, indexing and searching baseball stats.
28
Creates a new Stats instance.
32
self.columns = self.loadcolumns()
35
self.stats = self.load()
38
self.names = self.loadnames()
41
self.vectors, self.data, self.embeddings = self.index()
43
def loadcolumns(self):
45
Returns a list of data columns.
51
raise NotImplementedError
55
Loads and returns raw stats.
61
raise NotImplementedError
65
Primary metric column.
71
raise NotImplementedError
73
def vector(self, row):
75
Build a vector for input row.
84
raise NotImplementedError
88
Loads a name - player id dictionary.
91
{player name: player id}
96
rows = self.stats.sort_values(by=self.metric(), ascending=False)[["nameFirst", "nameLast", "playerID"]].drop_duplicates().reset_index()
97
for x, row in rows.iterrows():
99
key = f"{row['nameFirst']} {row['nameLast']}"
100
key += f" ({row['playerID']})" if key in names else ""
104
exponent = 2 if ((len(rows) - x) / len(rows)) >= 0.95 else 1
107
score = math.pow(len(self.stats[self.stats["playerID"] == row["playerID"]]), exponent)
110
names[key] = (row["playerID"], score)
116
Builds an embeddings index to stats data. Returns vectors, input data and embeddings index.
119
vectors, data, embeddings
123
vectors = {f'{row["yearID"]}{row["playerID"]}': self.transform(row) for _, row in self.stats.iterrows()}
124
data = {f'{row["yearID"]}{row["playerID"]}': dict(row) for _, row in self.stats.iterrows()}
126
embeddings = Embeddings(
128
"transform": self.transform,
132
embeddings.index((uid, vectors[uid], None) for uid in vectors)
134
return vectors, data, embeddings
136
def metrics(self, name):
138
Looks up a player's active years, best statistical year and key metrics.
144
active, best, metrics
147
if name in self.names:
149
stats = self.stats[self.stats["playerID"] == self.names[name][0]]
152
metrics = stats[["yearID", self.metric()]]
155
best = int(stats.sort_values(by=self.metric(), ascending=False)["yearID"].iloc[0])
158
return metrics["yearID"].tolist(), best, metrics
160
return range(1871, datetime.datetime.today().year), 1950, None
162
def search(self, name=None, year=None, row=None, limit=10):
164
Runs an embeddings search. This method takes either a player-year or stats row as input.
167
name: player name to search
169
row: row of stats to search
170
limit: max results to return
177
query = self.vector(row)
180
name = self.names.get(name)
181
query = f"{year}{name[0] if name else name}"
182
query = self.vectors.get(query)
184
results, ids = [], set()
185
if query is not None:
186
for uid, _ in self.embeddings.search(query, limit * 5):
188
if uid[4:] not in ids:
189
result = self.data[uid].copy()
190
result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml'
191
results.append(result)
194
if len(ids) >= limit:
199
def transform(self, row):
201
Transforms a stats row into a vector.
210
if isinstance(row, np.ndarray):
213
return np.array([0.0 if not row[x] or np.isnan(row[x]) else row[x] for x in self.columns])
221
def loadcolumns(self):
257
players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv")
258
batting = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Batting.csv")
259
fielding = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Fielding.csv")
262
batting = pd.merge(players, batting, how="inner", on=["playerID"])
265
batting = batting[((batting["AB"] + batting["BB"]) >= 350) & (batting["stint"] == 1)]
268
positions = self.positions(fielding)
271
batting["age"] = batting["yearID"] - batting["birthYear"]
272
batting["POS"] = batting.apply(lambda row: self.position(positions, row), axis=1)
273
batting["AVG"] = batting["H"] / batting["AB"]
274
batting["OBP"] = (batting["H"] + batting["BB"]) / (batting["AB"] + batting["BB"])
275
batting["1B"] = batting["H"] - batting["2B"] - batting["3B"] - batting["HR"]
276
batting["TB"] = batting["1B"] + 2 * batting["2B"] + 3 * batting["3B"] + 4 * batting["HR"]
277
batting["SLG"] = batting["TB"] / batting["AB"]
278
batting["OPS"] = batting["OBP"] + batting["SLG"]
279
batting["OPS+"] = 100 + (batting["OPS"] - batting["OPS"].mean()) * 100
286
def vector(self, row):
287
row["TB"] = row["1B"] + 2 * row["2B"] + 3 * row["3B"] + 4 * row["HR"]
288
row["AVG"] = row["H"] / row["AB"]
289
row["OBP"] = (row["H"] + row["BB"]) / (row["AB"] + row["BB"])
290
row["SLG"] = row["TB"] / row["AB"]
291
row["OPS"] = row["OBP"] + row["SLG"]
292
row["OPS+"] = 100 + (row["OPS"] - self.stats["OPS"].mean()) * 100
294
return self.transform(row)
296
def positions(self, fielding):
298
Derives primary positions for players.
301
fielding: fielding data
304
{player id: (position, number of games)}
308
for _, row in fielding.iterrows():
309
uid = f'{row["yearID"]}{row["playerID"]}'
310
position = row["POS"] if row["POS"] else 0
313
elif position == "C":
315
elif position == "1B":
317
elif position == "2B":
319
elif position == "3B":
321
elif position == "SS":
323
elif position == "OF":
327
if uid not in positions or positions[uid][1] < row["G"]:
328
positions[uid] = (position, row["G"])
332
def position(self, positions, row):
334
Looks up primary position for player row.
337
positions: all player positions
341
primary player positions
344
uid = f'{row["yearID"]}{row["playerID"]}'
345
return positions[uid][0] if uid in positions else 0
348
class Pitching(Stats):
353
def loadcolumns(self):
391
players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv")
392
pitching = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Pitching.csv")
395
pitching = pd.merge(players, pitching, how="inner", on=["playerID"])
398
pitching = pitching[(pitching["G"] >= 20) & (pitching["stint"] == 1)]
401
pitching["age"] = pitching["yearID"] - pitching["birthYear"]
402
pitching["WHIP"] = (pitching["BB"] + pitching["H"]) / (pitching["IPouts"] / 3)
403
pitching["WADJ"] = (pitching["W"] + pitching["SV"]) / (pitching["ERA"] + pitching["WHIP"])
410
def vector(self, row):
411
row["WHIP"] = (row["BB"] + row["H"]) / (row["IPouts"] / 3) if row["IPouts"] else None
412
row["WADJ"] = (row["W"] + row["SV"]) / (row["ERA"] + row["WHIP"]) if row["ERA"] and row["WHIP"] else None
414
return self.transform(row)
424
Creates a new application.
428
self.batting = Batting()
431
self.pitching = Pitching()
435
Runs a Streamlit application.
438
st.title("⚾ Baseball Statistics")
441
This application finds the best matching historical players using vector search with [txtai](https://github.com/neuml/txtai).
442
Raw data is from the [Baseball Databank](https://github.com/chadwickbureau/baseballdatabank) GitHub project. Read [this
443
article](https://medium.com/neuml/explore-baseball-history-with-vector-search-5778d98d6846) for more details.
447
player, search = st.tabs(["Player", "Search"])
462
st.markdown("Match by player-season. Each player search defaults to the best season sorted by OPS or Wins Adjusted.")
465
params = self.params()
468
category = self.category(params.get("category"), "category")
469
stats = self.batting if category == "Batting" else self.pitching
472
name = self.name(stats.names, params.get("name"))
475
active, best, metrics = stats.metrics(name)
478
year = self.year(active, params.get("year"), best)
482
self.chart(category, metrics)
485
results = stats.search(name, year)
488
self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])
491
st.experimental_set_query_params(category=category, name=name, year=year)
498
st.markdown("Find players with similar statistics.")
500
category = self.category("Batting", "searchcategory")
501
with st.form("search"):
502
if category == "Batting":
503
stats, columns = self.batting, self.batting.columns[:-6]
504
elif category == "Pitching":
505
stats, columns = self.pitching, self.pitching.columns[:-2]
508
inputs = st.data_editor(pd.DataFrame([dict((column, None) for column in columns)]), hide_index=True).astype(float)
510
submitted = st.form_submit_button("Search")
513
results = stats.search(row=inputs.to_dict(orient="records")[0])
516
self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])
520
Get application parameters. This method combines URL parameters with session parameters.
527
params = st.experimental_get_query_params()
528
params = {x: params[x][0] for x in params}
531
if all(x in st.session_state for x in ["category", "name", "year"]):
533
params["year"] = str(st.session_state["year"]) if all(params.get(x) == st.session_state[x] for x in ["category", "name"]) else None
536
params["category"] = st.session_state["category"]
537
params["name"] = st.session_state["name"]
541
def category(self, category, key):
543
Builds category input widget.
546
category: category parameter
554
categories = ["Batting", "Pitching"]
557
default = categories.index(category) if category and category in categories else 0
560
return st.radio("Stat", categories, index=default, horizontal=True, key=key)
562
def name(self, names, name):
564
Builds name input widget.
567
names: list of all allowable names
574
name = name if name and name in names else random.choices(list(names.keys()), weights=[names[x][1] for x in names])[0]
577
names = sorted(names)
580
return st.selectbox("Name", names, names.index(name), key="name")
582
def year(self, years, year, best):
584
Builds year input widget.
587
years: active years for a player
589
best: default to best year if year is invalid
596
year = int(year) if year and year.isdigit() and int(year) in years else best
599
return int(st.select_slider("Year", years, year, key="year") if len(years) > 1 else years[0])
601
def chart(self, category, metrics):
603
Displays a metric chart.
606
category: Batting or Pitching
607
metrics: player metrics to plot
611
metric = self.batting.metric() if category == "Batting" else self.pitching.metric()
614
metrics["yearID"] = metrics["yearID"].astype(str)
619
.mark_line(interpolate="monotone", point=True, strokeWidth=2.5, opacity=0.75)
620
.encode(x=alt.X("yearID", title=""), y=alt.Y(metric, scale=alt.Scale(zero=False)))
624
rule = alt.Chart(metrics).mark_rule(color="gray", strokeDash=[3, 5], opacity=0.5).encode(y=f"median({metric})")
627
chart = (chart + rule).encode(y=alt.Y(title=metric)).properties(height=200).configure_axis(grid=False)
630
st.altair_chart(chart + rule, theme="streamlit", use_container_width=True)
632
def table(self, results, columns):
634
Displays a list of results as a table.
637
results: list of results
638
columns: column names
644
column_order=columns,
646
"link": st.column_config.LinkColumn("Link", width="small"),
647
"yearID": st.column_config.NumberColumn("Year", format="%d"),
648
"nameFirst": "First",
657
st.write("Player-Year not found")
660
@st.cache_resource(show_spinner=False)
663
Creates and caches a Streamlit application.
672
if __name__ == "__main__":
673
os.environ["TOKENIZERS_PARALLELISM"] = "false"