txtai

Форк
0
/
benchmarks.py 
603 строки · 16.9 Кб
1
"""
2
Runs benchmark evaluations with the BEIR dataset.
3

4
Install txtai and the following dependencies to run:
5
    pip install txtai pytrec_eval rank-bm25 elasticsearch psutil
6
"""
7

8
import argparse
import csv
import json
import os
import pickle
import sqlite3
import time

import psutil
import yaml

import numpy as np

from rank_bm25 import BM25Okapi
from pytrec_eval import RelevanceEvaluator

from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch.helpers import bulk

from txtai.embeddings import Embeddings
from txtai.pipeline import Extractor, LLM, Tokenizer
from txtai.scoring import ScoringFactory


class Index:
    """
    Base index definition. Defines methods to index and search a dataset.

    Subclasses implement `index()`; its return value is stored as `self.backend` and is
    used by the default `search()` implementation.
    """

    def __init__(self, path, config, output, refresh):
        """
        Creates a new index.

        Args:
            path: path to dataset
            config: path to config file
            output: path to store index
            refresh: overwrites existing index if True, otherwise existing index is loaded
        """

        self.path = path
        self.config = config
        self.output = output
        self.refresh = refresh

        # Build and save index
        self.backend = self.index()

    def __call__(self, limit, filterscores=True):
        """
        Main evaluation logic. Loads an index, runs the dataset queries and returns the results.

        Args:
            limit: maximum results
            filterscores: if True, results whose id equals the query id are dropped

        Returns:
            {query id: {result id: score}} with at most `limit` results per query
        """

        uids, queries = self.load()

        # Run queries in batches
        offset, results = 0, {}
        for batch in self.batch(queries, 256):
            # Request one extra hit so `limit` results remain after dropping a self-match
            for i, r in enumerate(self.search(batch, limit + 1)):
                uid = uids[offset + i]

                # Get result as list of (id, score) tuples
                r = list(r)
                r = [(x["id"], x["score"]) for x in r] if r and isinstance(r[0], dict) else r

                if filterscores:
                    r = [(rid, score) for rid, score in r if rid != uid]

                # Always cap at limit: the extra hit must not leak through when filtering is off
                results[uid] = dict(r[:limit])

            # Increment offset
            offset += len(batch)

        return results

    def search(self, queries, limit):
        """
        Runs a search for a set of queries.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            search results
        """

        return self.backend.batchsearch(queries, limit)

    def index(self):
        """
        Indexes a dataset.
        """

        raise NotImplementedError

    def rows(self):
        """
        Iterates over the dataset yielding a row at a time for indexing.

        Each row is (id, text, None) where text is "title. text" when a title is present.
        Rows with no text and blank lines are skipped.
        """

        with open(f"{self.path}/corpus.jsonl", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline) in the JSON lines file
                if not line.strip():
                    continue

                row = json.loads(line)
                title = row.get("title")
                text = f'{title}. {row["text"]}' if title else row["text"]
                if text:
                    yield (row["_id"], text, None)

    def load(self):
        """
        Loads queries for the dataset. Returns a list of expected result ids and input queries.

        Returns:
            (result ids, input queries)
        """

        with open(f"{self.path}/queries.jsonl", encoding="utf-8") as f:
            data = [json.loads(query) for query in f]
            uids, queries = [x["_id"] for x in data], [x["text"] for x in data]

        return uids, queries

    def batch(self, data, size):
        """
        Splits data into equal sized batches.

        Args:
            data: input data
            size: batch size

        Returns:
            data split into equal size batches (the last batch may be smaller)
        """

        return [data[x : x + size] for x in range(0, len(data), size)]

    def readconfig(self, key, default):
        """
        Reads configuration from a config file. Returns default configuration
        if config file is not found, is empty or config key isn't present.

        Args:
            key: configuration key to lookup
            default: default configuration

        Returns:
            config if found, otherwise returns default config
        """

        if self.config and os.path.exists(self.config):
            # Read configuration
            with open(self.config, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f)

            # An empty YAML file loads as None; only a mapping can hold config keys
            if isinstance(config, dict) and key in config:
                return config[key]

        return default


class Score(Index):
    """
    BM25 index using txtai.

    Builds a txtai scoring instance from the "scoring" config section. A previously saved
    index at `output` is reused unless a refresh is requested.
    """

    def index(self):
        # Defaults used when the config file has no "scoring" section
        defaults = {"method": "bm25", "terms": True}
        backend = ScoringFactory.create(self.readconfig("scoring", defaults))

        reuse = not self.refresh and os.path.exists(self.output)
        if not reuse:
            # Build from the corpus and persist for later runs
            backend.index(self.rows())
            backend.save(self.output)
        else:
            backend.load(self.output)

        return backend


class Embed(Index):
    """
    Embeddings index using txtai.

    Loads the embeddings saved at `output` unless a refresh is requested. Otherwise it builds
    a new index from the "embeddings" config section and saves it.
    """

    def index(self):
        # Fast path: reuse a previously saved index
        if not self.refresh and os.path.exists(self.output):
            embeddings = Embeddings()
            embeddings.load(self.output)
            return embeddings

        # Defaults used when the config file has no "embeddings" section
        defaults = {"batch": 8192, "encodebatch": 128, "faiss": {"quantize": True, "sample": 0.05}}

        # Build index and persist it
        embeddings = Embeddings(self.readconfig("embeddings", defaults))
        embeddings.index(self.rows())
        embeddings.save(self.output)
        return embeddings


class Hybrid(Index):
    """
    Hybrid embeddings + BM25 index using txtai.

    Loads the embeddings saved at `output` unless a refresh is requested. Otherwise it builds
    a new index from the "hybrid" config section (faiss settings plus BM25 scoring with
    normalize enabled) and saves it.
    """

    def index(self):
        # Fast path: reuse a previously saved index
        if not self.refresh and os.path.exists(self.output):
            embeddings = Embeddings()
            embeddings.load(self.output)
            return embeddings

        # Defaults used when the config file has no "hybrid" section
        defaults = {
            "batch": 8192,
            "encodebatch": 128,
            "faiss": {"quantize": True, "sample": 0.05},
            "scoring": {"method": "bm25", "terms": True, "normalize": True},
        }

        # Build index and persist it
        embeddings = Embeddings(self.readconfig("hybrid", defaults))
        embeddings.index(self.rows())
        embeddings.save(self.output)
        return embeddings


class RAG(Embed):
    """
    Retrieval augmented generation (RAG) using txtai.

    Wraps the embeddings index built by Embed in an Extractor backed by an LLM. Each query
    yields a single result: the reference the Extractor returns.
    """

    def __init__(self, path, config, output, refresh):
        # Build or load the embeddings index first (sets self.backend)
        super().__init__(path, config, output, refresh)

        # Optional keyword arguments for the LLM and Extractor from the config file
        llmconfig = self.readconfig("llm", {})
        extractorconfig = self.readconfig("extractor", {})

        self.extractor = Extractor(self.backend, LLM(**llmconfig), output="reference", **extractorconfig)

    def search(self, queries, limit):
        # The context window holds `limit` retrieved passages
        self.extractor.context = limit

        answers = self.extractor(queries, maxlength=4096)
        return [[(answer["reference"], 1)] for answer in answers]


class RankBM25(Index):
    """
    BM25 index using rank-bm25.

    The backend is the pair (document ids, BM25Okapi model). It is pickled to `output` so
    later runs without refresh can reuse it.
    """

    def search(self, queries, limit):
        """
        Scores every document for each query and keeps the highest scoring ones.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            list of [(id, score)] per query, highest score first
        """

        ids, backend = self.backend
        tokenizer, results = Tokenizer(), []
        for query in queries:
            scores = backend.get_scores(tokenizer(query))

            # Indices of the top `limit` scores, highest first
            topn = np.argsort(scores)[::-1][:limit]
            results.append([(ids[x], scores[x]) for x in topn])

        return results

    def index(self):
        """
        Loads the pickled index from `output`, or builds it from the corpus and saves it.

        Returns:
            (document ids, BM25Okapi model)
        """

        if os.path.exists(self.output) and not self.refresh:
            # NOTE: pickle.load can execute arbitrary code - only load index files written by this script
            with open(self.output, "rb") as f:
                ids, model = pickle.load(f)
        else:
            # Tokenize data
            tokenizer, data = Tokenizer(), []
            for uid, text, _ in self.rows():
                data.append((uid, tokenizer(text)))

            ids = [uid for uid, _ in data]
            model = BM25Okapi([text for _, text in data])

            # Persist the index; without this the load branch above could never be reached
            with open(self.output, "wb") as f:
                pickle.dump((ids, model), f)

        return ids, model


class SQLiteFTS(Index):
    """
    BM25 index using SQLite's FTS extension.

    Pre-tokenized documents are stored in an FTS5 virtual table inside a SQLite database
    file at `output`. The backend is a cursor on that database.
    """

    def search(self, queries, limit):
        """
        Runs an OR query over the query tokens for each query.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            list of [(id, score)] per query, best match first
        """

        tokenizer, results = Tokenizer(), []
        for query in queries:
            # Each token becomes an FTS5 string literal (embedded double quotes are escaped
            # by doubling them) and the literals are OR'ed together
            terms = " OR ".join('"' + token.replace('"', '""') + '"' for token in tokenizer(query))

            # Results are ordered by ascending bm25(); the value is negated so reported scores
            # are higher-is-better. Both the match expression and the limit are bound parameters
            self.backend.execute(
                "SELECT id, bm25(textindex) * -1 score FROM textindex WHERE text MATCH ? ORDER BY bm25(textindex) LIMIT ?",
                [terms, limit],
            )

            results.append(list(self.backend))

        return results

    def index(self):
        """
        Opens the existing database at `output`, or rebuilds it from the corpus.

        Returns:
            cursor on the database
        """

        if os.path.exists(self.output) and not self.refresh:
            # Load existing database
            connection = sqlite3.connect(self.output)
        else:
            # Delete existing database
            if os.path.exists(self.output):
                os.remove(self.output)

            # Create new database
            connection = sqlite3.connect(self.output)

            # Tokenize data, storing tokens space separated
            tokenizer, data = Tokenizer(), []
            for uid, text, _ in self.rows():
                data.append((uid, " ".join(tokenizer(text))))

            # Create table
            connection.execute("CREATE VIRTUAL TABLE textindex using fts5(id, text)")

            # Load data and build index
            connection.executemany("INSERT INTO textindex VALUES (?, ?)", data)

            connection.commit()

        return connection.cursor()


class Elastic(Index):
    """
    BM25 index using Elasticsearch.

    Connects to http://localhost:9200 and rebuilds the "textindex" index on every run.
    The output and refresh settings are not used by this index.
    """

    def search(self, queries, limit):
        """
        Runs all queries in a single multi-search request.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            list of [(id, score)] per query
        """

        # Generate bulk queries: one header + body pair per query
        request = []
        for query in queries:
            req_head = {"index": "textindex", "search_type": "dfs_query_then_fetch"}
            req_body = {
                "_source": False,
                "query": {"multi_match": {"query": query, "type": "best_fields", "fields": ["text"], "tie_breaker": 0.5}},
                "size": limit,
            }
            request.extend([req_head, req_body])

        # Run ES query
        response = self.backend.msearch(body=request, request_timeout=600)

        # Read responses
        results = []
        for resp in response["responses"]:
            result = resp["hits"]["hits"]
            results.append([(r["_id"], r["_score"]) for r in result])

        return results

    def index(self):
        """
        Recreates the "textindex" index from the corpus.

        Returns:
            Elasticsearch client
        """

        es = Elasticsearch("http://localhost:9200")

        # Drop the index from a previous run. A missing index (first run) is expected;
        # other failures such as connection errors are no longer silently swallowed
        try:
            es.indices.delete(index="textindex")
        except NotFoundError:
            pass

        bulk(es, ({"_index": "textindex", "_id": uid, "text": text} for uid, text, _ in self.rows()))

        # Make the bulk loaded documents visible to search
        es.indices.refresh(index="textindex")

        return es


def relevance(path):
    """
    Loads relevance judgments (qrels) for evaluation.

    Reads `<path>/qrels/test.tsv`, skipping its header row. Each remaining row is
    query id, corpus id, integer score.

    Args:
        path: path to dataset directory

    Returns:
        {query id: {corpus id: score}}
    """

    judgments = {}
    with open(f"{path}/qrels/test.tsv", encoding="utf-8") as handle:
        rows = csv.reader(handle, delimiter="\t", quoting=csv.QUOTE_MINIMAL)

        # Skip header row
        next(rows)

        for row in rows:
            judgments.setdefault(row[0], {})[row[1]] = int(row[2])

    return judgments


def create(method, path, config, output, refresh):
    """
    Creates a new index.

    Method names: "es" (Elasticsearch), "hybrid", "rag", "scoring" or "bm25" (txtai BM25),
    "sqlite" (SQLite FTS5), "rank" (rank-bm25). Any other name, e.g. "embed", creates an
    embeddings index.

    Args:
        method: indexing method
        path: path to dataset
        config: path to config file
        output: path to store index
        refresh: overwrites existing index if True, otherwise existing index is loaded

    Returns:
        Index
    """

    indexes = {
        # "bm25" is used by the default method list in benchmarks(); without this alias it
        # silently fell through to the dense embeddings default
        "bm25": Score,
        "es": Elastic,
        "hybrid": Hybrid,
        "rag": RAG,
        "rank": RankBM25,
        "scoring": Score,
        "sqlite": SQLiteFTS,
    }

    # Default
    return indexes.get(method, Embed)(path, config, output, refresh)


def compute(results):
    """
    Averages per-query metrics from an evaluation run.

    Args:
        results: {query id: {metric name: value}}

    Returns:
        {metric name: mean value rounded to 5 decimals}, in first-seen metric order
    """

    collected = {}
    for scores in results.values():
        for name, value in scores.items():
            collected.setdefault(name, []).append(value)

    return {name: round(np.mean(values), 5) for name, values in collected.items()}


def evaluate(methods, path, args):
    """
    Runs an evaluation.

    For each method, this builds (or loads) the index and records indexing time, process
    memory and on-disk index size. It then runs the dataset queries and scores them against
    the qrels.

    Args:
        methods: list of indexing methods to test
        path: path to dataset
        args: command line arguments (reads topk, output, config and refresh)

    Returns:
        {method: {calculated performance metrics}}
    """

    print(f"------ {os.path.basename(path)} ------")

    # Performance stats
    performance = {}

    # Calculate stats for each model type
    topk = args.topk
    evaluator = RelevanceEvaluator(relevance(path), {f"ndcg_cut.{topk}", f"map_cut.{topk}", f"recall.{topk}", f"P.{topk}"})
    for method in methods:
        # Stats for this source
        stats = {}
        performance[method] = stats

        # Create index and get results
        # NOTE(review): when --output is set, all methods and sources share that single path
        start = time.time()
        output = args.output if args.output else f"{path}/{method}"
        index = create(method, path, args.config, output, args.refresh)
        elapsed = time.time() - start

        # Add indexing metrics
        stats["index"] = round(elapsed, 2)
        stats["memory"] = int(psutil.Process().memory_info().rss / (1024 * 1024))

        # Index size: directory outputs sum their top-level files, single-file outputs
        # (e.g. the SQLite database) use the file size, missing outputs count as 0
        if os.path.isdir(output):
            with os.scandir(output) as entries:
                size = sum(entry.stat().st_size for entry in entries if entry.is_file())
        elif os.path.isfile(output):
            size = os.path.getsize(output)
        else:
            size = 0
        stats["disk"] = int(size / 1024)

        # Report the indexing time itself, excluding the memory/disk probes above
        print("INDEX TIME =", elapsed)
        print(f"MEMORY USAGE = {stats['memory']} MB")
        print(f"DISK USAGE = {stats['disk']} KB")

        start = time.time()
        results = index(topk)
        elapsed = time.time() - start

        # Add search metrics
        stats["search"] = round(elapsed, 2)
        print("SEARCH TIME =", elapsed)

        # Calculate stats
        metrics = compute(evaluator.evaluate(results))

        # Add accuracy metrics
        for stat in [f"ndcg_cut_{topk}", f"map_cut_{topk}", f"recall_{topk}", f"P_{topk}"]:
            stats[stat] = metrics[stat]

        # Print model stats
        print(f"------ {method} ------")
        print(f"NDCG@{topk} =", metrics[f"ndcg_cut_{topk}"])
        print(f"MAP@{topk} =", metrics[f"map_cut_{topk}"])
        print(f"Recall@{topk} =", metrics[f"recall_{topk}"])
        print(f"P@{topk} =", metrics[f"P_{topk}"])

    print()
    return performance


def benchmarks(args):
    """
    Main benchmark execution method.

    Evaluates every requested method on every requested BEIR source and writes one JSON
    line per (source, method) to benchmarks.json. A full default run overwrites the file.
    A run with explicit sources and/or methods appends to it.

    Args:
        args: command line arguments (reads directory, sources, methods and name, plus the
              options used by evaluate())
    """

    def parse(value, default):
        # Comma separated list with surrounding whitespace and empty entries removed;
        # falls back to the default when nothing usable was given
        values = [x.strip() for x in value.split(",") if x.strip()] if value else []
        return values if values else default

    # Directory where BEIR datasets are stored
    directory = args.directory if args.directory else "beir"

    # Default sources and methods, applied independently to whichever list was not given
    sources = parse(
        args.sources,
        [
            "trec-covid",
            "nfcorpus",
            "nq",
            "hotpotqa",
            "fiqa",
            "arguana",
            "webis-touche2020",
            "quora",
            "dbpedia-entity",
            "scidocs",
            "fever",
            "climate-fever",
            "scifact",
        ],
    )
    methods = parse(args.methods, ["bm25", "embed", "es", "hybrid", "rank", "sqlite"])

    # Only a full default run rewrites the results file; targeted runs append to it
    mode = "a" if args.sources or args.methods else "w"

    # Run and save benchmarks
    with open("benchmarks.json", mode, encoding="utf-8") as f:
        for source in sources:
            # Run evaluations
            results = evaluate(methods, f"{directory}/{source}", args)

            # Save as JSON lines output
            for method, stats in results.items():
                stats["source"] = source
                stats["method"] = method
                stats["name"] = args.name if args.name else method

                json.dump(stats, f)
                f.write("\n")


if __name__ == "__main__":
    # Command line interface; each option becomes an attribute read by benchmarks()
    parser = argparse.ArgumentParser(description="Benchmarks")
    option = parser.add_argument

    option("-c", "--config", help="path to config file", metavar="CONFIG")
    option("-d", "--directory", help="root directory path with datasets", metavar="DIRECTORY")
    option("-m", "--methods", help="comma separated list of methods", metavar="METHODS")
    option("-n", "--name", help="name to assign to this run, defaults to method name", metavar="NAME")
    option("-o", "--output", help="index output directory path", metavar="OUTPUT")
    option(
        "-r",
        "--refresh",
        action="store_true",
        help="refreshes index if set, otherwise uses existing index if available",
    )
    option("-s", "--sources", help="comma separated list of data sources", metavar="SOURCES")
    option("-t", "--topk", help="top k results to use for the evaluation", metavar="TOPK", type=int, default=10)

    # Parse arguments and run the requested benchmarks
    arguments = parser.parse_args()
    benchmarks(arguments)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.