txtai

Форк
0
/
baseball.py 
677 строк · 19.0 Кб
"""
Baseball statistics application with txtai and Streamlit.

Install txtai and streamlit (>= 1.23) to run:
  pip install txtai streamlit
"""
7

8
import datetime
9
import math
10
import os
11
import random
12

13
import altair as alt
14
import numpy as np
15
import pandas as pd
16
import streamlit as st
17

18
from txtai.embeddings import Embeddings
19

20

21
class Stats:
    """
    Base stats class. Contains methods for loading, indexing and searching baseball stats.

    Subclasses implement loadcolumns(), load(), metric() and vector().
    """

    def __init__(self):
        """
        Creates a new Stats instance. Loads columns, raw stats and player names, then builds
        the embeddings index over the stats rows.
        """

        # Load columns
        self.columns = self.loadcolumns()

        # Load stats data
        self.stats = self.load()

        # Load names
        self.names = self.loadnames()

        # Build index
        self.vectors, self.data, self.embeddings = self.index()

    def loadcolumns(self):
        """
        Returns a list of data columns.

        Returns:
            list of columns
        """

        raise NotImplementedError

    def load(self):
        """
        Loads and returns raw stats.

        Returns:
            stats
        """

        raise NotImplementedError

    def metric(self):
        """
        Primary metric column.

        Returns:
            metric column name
        """

        raise NotImplementedError

    def vector(self, row):
        """
        Build a vector for input row.

        Args:
            row: input row

        Returns:
            row vector
        """

        raise NotImplementedError

    def loadnames(self):
        """
        Loads a name -> (player id, selection weight) dictionary.

        Players are visited in descending order of the primary metric. When a display name is
        already taken by a higher-ranked player, the player id is appended to keep keys unique.

        Returns:
            {player name: (player id, weight)}
        """

        # Seasons per player, counted once instead of re-scanning all stats for every player
        seasons = self.stats["playerID"].value_counts()

        # Unique players, best primary metric first
        rows = (
            self.stats.sort_values(by=self.metric(), ascending=False)[["nameFirst", "nameLast", "playerID"]]
            .drop_duplicates()
            .reset_index(drop=True)
        )
        total = len(rows)

        names = {}
        for x, (first, last, uid) in enumerate(zip(rows["nameFirst"], rows["nameLast"], rows["playerID"])):
            # Name key, disambiguated with the player id when the plain name is already used
            key = f"{first} {last}"
            key += f" ({uid})" if key in names else ""

            if key not in names:
                # Square the scores of the top ~5% of players by rank
                exponent = 2 if ((total - x) / total) >= 0.95 else 1

                # score = num seasons ^ exponent
                names[key] = (uid, math.pow(seasons[uid], exponent))

        return names

    def index(self):
        """
        Builds an embeddings index to stats data. Returns vectors, input data and embeddings index.

        Each row is keyed by f"{yearID}{playerID}".

        Returns:
            vectors, data, embeddings
        """

        # Build vector and data dictionaries in a single pass over the stats rows
        vectors, data = {}, {}
        for _, row in self.stats.iterrows():
            uid = f'{row["yearID"]}{row["playerID"]}'
            vectors[uid] = self.transform(row)
            data[uid] = dict(row)

        embeddings = Embeddings(
            {
                "transform": self.transform,
            }
        )

        embeddings.index((uid, vector, None) for uid, vector in vectors.items())

        return vectors, data, embeddings

    def metrics(self, name):
        """
        Looks up a player's active years, best statistical year and key metrics.

        Args:
            name: player name

        Returns:
            active, best, metrics. For an unknown name: every year from 1871 through last year,
            1950 and None.
        """

        if name in self.names:
            # Get player stats
            stats = self.stats[self.stats["playerID"] == self.names[name][0]]

            # Build key metrics. copy() so callers can modify the returned frame without
            # affecting self.stats or triggering chained-assignment warnings.
            metrics = stats[["yearID", self.metric()]].copy()

            # Get best year, sort by primary metric
            best = int(stats.sort_values(by=self.metric(), ascending=False)["yearID"].iloc[0])

            # Get years active, best year, along with metric trends
            return metrics["yearID"].tolist(), best, metrics

        return range(1871, datetime.datetime.today().year), 1950, None

    def search(self, name=None, year=None, row=None, limit=10):
        """
        Runs an embeddings search. This method takes either a player-year or stats row as input.

        Args:
            name: player name to search
            year: year to search
            row: row of stats to search
            limit: max results to return

        Returns:
            list of results, at most one per player
        """

        if row:
            query = self.vector(row)
        else:
            # Lookup player key and build vector id
            name = self.names.get(name)
            query = f"{year}{name[0] if name else name}"
            query = self.vectors.get(query)

        results, ids = [], set()
        if query is not None:
            for uid, _ in self.embeddings.search(query, limit * 5):
                # Only add unique players. uid is "<yearID><playerID>", so uid[4:] is the player id
                # (assumes 4-digit years).
                if uid[4:] not in ids:
                    result = self.data[uid].copy()
                    result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml'
                    results.append(result)
                    ids.add(uid[4:])

                    if len(ids) >= limit:
                        break

        return results

    def transform(self, row):
        """
        Transforms a stats row into a vector. Missing, falsy and NaN values become 0.0.

        Args:
            row: stats row, or an already-built numpy vector (returned unchanged)

        Returns:
            vector
        """

        if isinstance(row, np.ndarray):
            return row

        return np.array([0.0 if not row[x] or np.isnan(row[x]) else row[x] for x in self.columns])
214

215

216
class Batting(Stats):
    """
    Batting stats.
    """

    # Fielding.csv position code -> numeric position. Missing or unknown codes map to 0.
    POSITIONS = {"P": 1, "C": 2, "1B": 3, "2B": 4, "3B": 5, "SS": 6, "OF": 7}

    def loadcolumns(self):
        return [
            "birthMonth",
            "yearID",
            "age",
            "height",
            "weight",
            "G",
            "AB",
            "R",
            "H",
            "1B",
            "2B",
            "3B",
            "HR",
            "RBI",
            "SB",
            "CS",
            "BB",
            "SO",
            "IBB",
            "HBP",
            "SH",
            "SF",
            "GIDP",
            "POS",
            "AVG",
            "OBP",
            "TB",
            "SLG",
            "OPS",
            "OPS+",
        ]

    def load(self):
        # Retrieve raw data from GitHub
        players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv")
        batting = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Batting.csv")
        fielding = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Fielding.csv")

        # Merge player data in
        batting = pd.merge(players, batting, how="inner", on=["playerID"])

        # Require player to have at least 350 plate appearances (approximated as AB + BB), first stint only.
        # copy() so the calculated columns below are assigned to an independent frame, not a slice.
        batting = batting[((batting["AB"] + batting["BB"]) >= 350) & (batting["stint"] == 1)].copy()

        # Derive primary player positions
        positions = self.positions(fielding)

        # Calculated columns
        batting["age"] = batting["yearID"] - batting["birthYear"]
        batting["POS"] = [self.position(positions, row) for row in batting[["yearID", "playerID"]].to_dict(orient="records")]
        batting["AVG"] = batting["H"] / batting["AB"]
        batting["OBP"] = (batting["H"] + batting["BB"]) / (batting["AB"] + batting["BB"])
        batting["1B"] = batting["H"] - batting["2B"] - batting["3B"] - batting["HR"]
        batting["TB"] = batting["1B"] + 2 * batting["2B"] + 3 * batting["3B"] + 4 * batting["HR"]
        batting["SLG"] = batting["TB"] / batting["AB"]
        batting["OPS"] = batting["OBP"] + batting["SLG"]
        batting["OPS+"] = 100 + (batting["OPS"] - batting["OPS"].mean()) * 100

        return batting

    def metric(self):
        return "OPS+"

    def vector(self, row):
        """
        Derives calculated columns for a user-entered row, then builds its vector.

        Ratios whose denominator is zero are left as None (encoded as 0.0 by transform) instead
        of raising ZeroDivisionError or producing inf. This mirrors the guards in Pitching.vector.

        Args:
            row: dict of raw batting stats; modified in place

        Returns:
            row vector
        """

        ab, bb = row["AB"], row["BB"]

        row["TB"] = row["1B"] + 2 * row["2B"] + 3 * row["3B"] + 4 * row["HR"]
        row["AVG"] = row["H"] / ab if ab else None
        row["OBP"] = (row["H"] + bb) / (ab + bb) if ab + bb else None
        row["SLG"] = row["TB"] / ab if ab else None
        row["OPS"] = row["OBP"] + row["SLG"] if row["OBP"] is not None and row["SLG"] is not None else None
        row["OPS+"] = 100 + (row["OPS"] - self.stats["OPS"].mean()) * 100 if row["OPS"] is not None else None

        return self.transform(row)

    def positions(self, fielding):
        """
        Derives primary positions for players.

        For each player-year, the position with the most games wins; on ties the first row seen
        is kept. Position codes outside POSITIONS (including missing values) become 0, so every
        value stays numeric for transform().

        Args:
            fielding: fielding data

        Returns:
            {"<yearID><playerID>": (position, number of games)}
        """

        positions = {}
        for year, player, code, games in zip(fielding["yearID"], fielding["playerID"], fielding["POS"], fielding["G"]):
            uid = f"{year}{player}"
            position = self.POSITIONS.get(code, 0)

            # Save position if not set or player played more at this position
            if uid not in positions or positions[uid][1] < games:
                positions[uid] = (position, games)

        return positions

    def position(self, positions, row):
        """
        Looks up primary position for player row.

        Args:
            positions: all player positions
            row: player row

        Returns:
            primary player position, 0 if unknown
        """

        uid = f'{row["yearID"]}{row["playerID"]}'
        return positions[uid][0] if uid in positions else 0
346

347

348
class Pitching(Stats):
    """
    Pitching stats.
    """

    def loadcolumns(self):
        return [
            "birthMonth",
            "yearID",
            "age",
            "height",
            "weight",
            "W",
            "L",
            "G",
            "GS",
            "CG",
            "SHO",
            "SV",
            "IPouts",
            "H",
            "ER",
            "HR",
            "BB",
            "SO",
            "BAOpp",
            "ERA",
            "IBB",
            "WP",
            "HBP",
            "BK",
            "BFP",
            "GF",
            "R",
            "SH",
            "SF",
            "GIDP",
            "WHIP",
            "WADJ",
        ]

    def load(self):
        # Retrieve raw data from GitHub
        players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv")
        pitching = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Pitching.csv")

        # Merge player data in
        pitching = pd.merge(players, pitching, how="inner", on=["playerID"])

        # Require player to have 20 appearances, first stint only.
        # copy() so the calculated columns below are assigned to an independent frame, not a slice.
        pitching = pitching[(pitching["G"] >= 20) & (pitching["stint"] == 1)].copy()

        # Calculated columns
        pitching["age"] = pitching["yearID"] - pitching["birthYear"]
        pitching["WHIP"] = (pitching["BB"] + pitching["H"]) / (pitching["IPouts"] / 3)
        pitching["WADJ"] = (pitching["W"] + pitching["SV"]) / (pitching["ERA"] + pitching["WHIP"])

        # Zero denominators yield +/-inf, which transform() would copy straight into the vectors
        # (and which would dominate metric sorting). Map them to NaN so they encode as 0.0,
        # matching how vector() treats undefined ratios.
        derived = ["WHIP", "WADJ"]
        pitching[derived] = pitching[derived].replace([np.inf, -np.inf], np.nan)

        return pitching

    def metric(self):
        return "WADJ"

    def vector(self, row):
        """
        Derives calculated columns for a user-entered row, then builds its vector.

        Args:
            row: dict of raw pitching stats; modified in place

        Returns:
            row vector
        """

        # WHIP is undefined without recorded outs
        row["WHIP"] = (row["BB"] + row["H"]) / (row["IPouts"] / 3) if row["IPouts"] else None

        # WADJ is undefined only when its denominator (ERA + WHIP) is missing or zero. A 0.00 ERA
        # with a non-zero WHIP is a valid denominator and must still produce a WADJ value.
        if row["ERA"] is not None and row["WHIP"] is not None:
            denominator = row["ERA"] + row["WHIP"]
        else:
            denominator = None

        row["WADJ"] = (row["W"] + row["SV"]) / denominator if denominator else None

        return self.transform(row)
415

416

417
class Application:
    """
    Main application.
    """

    def __init__(self):
        """
        Creates a new application. Loads and indexes both batting and pitching stats.
        """

        # Batting stats
        self.batting = Batting()

        # Pitching stats
        self.pitching = Pitching()

    def run(self):
        """
        Runs a Streamlit application.
        """

        st.title("⚾ Baseball Statistics")
        st.markdown(
            """
            This application finds the best matching historical players using vector search with [txtai](https://github.com/neuml/txtai).
            Raw data is from the [Baseball Databank](https://github.com/chadwickbureau/baseballdatabank) GitHub project. Read [this
            article](https://medium.com/neuml/explore-baseball-history-with-vector-search-5778d98d6846) for more details.
        """
        )

        player, search = st.tabs(["Player", "Search"])

        # Player tab
        with player:
            self.player()

        # Search
        with search:
            self.search()

    def player(self):
        """
        Player tab.
        """

        st.markdown("Match by player-season. Each player search defaults to the best season sorted by OPS or Wins Adjusted.")

        # Get parameters
        params = self.params()

        # Category and stats
        category = self.category(params.get("category"), "category")
        stats = self.batting if category == "Batting" else self.pitching

        # Player name
        name = self.name(stats.names, params.get("name"))

        # Player metrics
        active, best, metrics = stats.metrics(name)

        # Player year
        year = self.year(active, params.get("year"), best)

        # Display metrics chart. metrics is None when the player is unknown, in which case
        # active is the full year range and there is nothing to plot.
        if metrics is not None and len(active) > 1:
            self.chart(category, metrics)

        # Run search
        results = stats.search(name, year)

        # Display results
        self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])

        # Save parameters
        st.experimental_set_query_params(category=category, name=name, year=year)

    def search(self):
        """
        Stats search tab. Finds players whose stats are closest to a user-entered stats row.
        """

        st.markdown("Find players with similar statistics.")

        category = self.category("Batting", "searchcategory")
        with st.form("search"):
            # Trailing derived columns are left out of the editor; vector() recomputes them
            if category == "Batting":
                stats, columns = self.batting, self.batting.columns[:-6]
            elif category == "Pitching":
                stats, columns = self.pitching, self.pitching.columns[:-2]

            # Enter stats with data editor
            inputs = st.data_editor(pd.DataFrame([dict((column, None) for column in columns)]), hide_index=True).astype(float)

            submitted = st.form_submit_button("Search")
            if submitted:
                # Run search
                results = stats.search(row=inputs.to_dict(orient="records")[0])

                # Display table
                self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])

    def params(self):
        """
        Get application parameters. This method combines URL parameters with session parameters.

        Returns:
            parameters
        """

        # Get parameters, keeping the first value of each query parameter
        params = st.experimental_get_query_params()
        params = {x: params[x][0] for x in params}

        # Sync parameters with session state
        if all(x in st.session_state for x in ["category", "name", "year"]):
            # Copy session year if category and name are unchanged
            params["year"] = str(st.session_state["year"]) if all(params.get(x) == st.session_state[x] for x in ["category", "name"]) else None

            # Copy category and name from session state
            params["category"] = st.session_state["category"]
            params["name"] = st.session_state["name"]

        return params

    def category(self, category, key):
        """
        Builds category input widget.

        Args:
            category: category parameter
            key: widget key

        Returns:
            category component
        """

        # List of stat categories
        categories = ["Batting", "Pitching"]

        # Get category parameter, default if not available or valid
        default = categories.index(category) if category and category in categories else 0

        # Radio box component
        return st.radio("Stat", categories, index=default, horizontal=True, key=key)

    def name(self, names, name):
        """
        Builds name input widget.

        Args:
            names: dictionary of all allowable names -> (player id, weight)
            name: name parameter; a weighted random name is used when missing or invalid

        Returns:
            name component
        """

        # Get name parameter, default to random weighted value if not valid
        name = name if name and name in names else random.choices(list(names.keys()), weights=[names[x][1] for x in names])[0]

        # Sort names for display
        names = sorted(names)

        # Select box component
        return st.selectbox("Name", names, names.index(name), key="name")

    def year(self, years, year, best):
        """
        Builds year input widget.

        Args:
            years: active years for a player
            year: year parameter
            best: default to best year if year is invalid

        Returns:
            year component
        """

        # Get year parameter, default if not available or valid
        year = int(year) if year and year.isdigit() and int(year) in years else best

        # Slider component
        return int(st.select_slider("Year", years, year, key="year") if len(years) > 1 else years[0])

    def chart(self, category, metrics):
        """
        Displays a metric chart.

        Args:
            category: Batting or Pitching
            metrics: player metrics to plot (not modified)
        """

        # Key metric
        metric = self.batting.metric() if category == "Batting" else self.pitching.metric()

        # Cast year to string for discrete axis labels; assign() returns a new frame so the
        # caller's data is left untouched
        metrics = metrics.assign(yearID=metrics["yearID"].astype(str))

        # Metric over years
        chart = (
            alt.Chart(metrics)
            .mark_line(interpolate="monotone", point=True, strokeWidth=2.5, opacity=0.75)
            .encode(x=alt.X("yearID", title=""), y=alt.Y(metric, scale=alt.Scale(zero=False)))
        )

        # Create metric median rule line
        rule = alt.Chart(metrics).mark_rule(color="gray", strokeDash=[3, 5], opacity=0.5).encode(y=f"median({metric})")

        # Layered chart configuration
        chart = (chart + rule).encode(y=alt.Y(title=metric)).properties(height=200).configure_axis(grid=False)

        # Draw chart. The rule is already part of the layer; adding it again would draw it twice.
        st.altair_chart(chart, theme="streamlit", use_container_width=True)

    def table(self, results, columns):
        """
        Displays a list of results as a table.

        Args:
            results: list of results
            columns: column names
        """

        if results:
            st.dataframe(
                results,
                column_order=columns,
                column_config={
                    "link": st.column_config.LinkColumn("Link", width="small"),
                    "yearID": st.column_config.NumberColumn("Year", format="%d"),
                    "nameFirst": "First",
                    "nameLast": "Last",
                    "teamID": "Team",
                    "age": "Age",
                    "weight": "Weight",
                    "height": "Height",
                },
            )
        else:
            st.write("Player-Year not found")
658

659

660
@st.cache_resource(show_spinner=False)
def create():
    """
    Builds the Application behind Streamlit's resource cache, so the stats are loaded and
    indexed once rather than on every script rerun.

    Returns:
        Application
    """

    application = Application()
    return application
670

671

672
if __name__ == "__main__":
    # Set before the application is built; presumably read by the tokenizers library used
    # while indexing (TODO confirm)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Fetch the (cached) application and render the UI
    create().run()
678

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.