"""
Runs benchmark evaluations with the BEIR dataset.

Install txtai and the following dependencies to run:
    pip install txtai pytrec_eval rank-bm25 elasticsearch psutil
"""
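
# Example invocation (a sketch: the script name benchmark.py and the beir/trec-covid dataset
# directory are assumptions, not documented defaults of this project):
#
#   python benchmark.py -s trec-covid -m bm25,embed -t 10
#
# Per-method metrics are printed to the console and written to benchmarks.json.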
import argparse
import csv
import json
import os
import pickle
import sqlite3
import time

import numpy as np
import psutil
import yaml

from rank_bm25 import BM25Okapi
from pytrec_eval import RelevanceEvaluator

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk

from txtai.embeddings import Embeddings
from txtai.pipeline import Extractor, LLM, Tokenizer
from txtai.scoring import ScoringFactory


class Index:
    """
    Base index definition. Defines methods to index and search a dataset.
    """

    def __init__(self, path, config, output, refresh):
        """
        Creates a new index.

        Args:
            path: path to dataset
            config: path to config file
            output: path to store index
            refresh: overwrites existing index if True, otherwise existing index is loaded
        """

        self.path = path
        self.config = config
        self.output = output
        self.refresh = refresh

        # Build a new index or load an existing one
        self.backend = self.index()

    def __call__(self, limit, filterscores=True):
        """
        Main evaluation logic. Loads an index, runs the dataset queries and returns the results.

        Args:
            limit: maximum results
            filterscores: if exact matches should be filtered out

        Returns:
            {query id: {result id: score}}
        """

        uids, queries = self.load()

        offset, results = 0, {}
        for batch in self.batch(queries, 256):
            for i, r in enumerate(self.search(batch, limit + 1)):
                # Normalize results to (id, score) tuples
                r = [(x["id"], x["score"]) for x in r] if r and isinstance(r[0], dict) else r

                # Remove exact match to the query id when filtering enabled, then trim to limit
                r = [(uid, score) for uid, score in r if uid != uids[offset + i]][:limit] if filterscores else r[:limit]

                results[uids[offset + i]] = dict(r)

            offset += len(batch)

        return results
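
    # Note on the loop above: searches request limit + 1 results and any hit whose id matches
    # the query id is dropped. For BEIR datasets where queries also appear as corpus documents,
    # this keeps the query document itself from counting as a retrieved result.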

    def search(self, queries, limit):
        """
        Runs a search for a set of queries.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            search results
        """

        return self.backend.batchsearch(queries, limit)

    def index(self):
        """
        Indexes a dataset. Must be implemented by subclasses.
        """

        raise NotImplementedError

    def rows(self):
        """
        Iterates over the dataset yielding a row at a time for indexing.
        """

        with open(f"{self.path}/corpus.jsonl", encoding="utf-8") as f:
            for line in f:
                row = json.loads(line)

                # Prepend title when present
                text = f'{row["title"]}. {row["text"]}' if row["title"] else row["text"]

                yield (row["_id"], text, None)

    def load(self):
        """
        Loads queries for the dataset. Returns a list of expected result ids and input queries.

        Returns:
            (result ids, input queries)
        """

        with open(f"{self.path}/queries.jsonl", encoding="utf-8") as f:
            data = [json.loads(query) for query in f]
            uids, queries = [x["_id"] for x in data], [x["text"] for x in data]

        return uids, queries

    def batch(self, data, size):
        """
        Splits data into equal sized batches.

        Returns:
            data split into equal size batches
        """

        return [data[x : x + size] for x in range(0, len(data), size)]

    def readconfig(self, key, default):
        """
        Reads configuration from a config file. Returns the default configuration
        if the config file is not found or the config key isn't present.

        Args:
            key: configuration key to lookup
            default: default configuration

        Returns:
            config if found, otherwise returns default config
        """

        if self.config and os.path.exists(self.config):
            with open(self.config, "r", encoding="utf-8") as f:
                # Return configuration for key, when present
                config = yaml.safe_load(f)
                if config and key in config:
                    return config[key]

        return default
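
# Example of a YAML config file that can be passed with -c. This is a sketch: the keys mirror
# the readconfig() calls below, but the model path and values shown are assumptions, not
# project defaults.
#
#   embeddings:
#     path: sentence-transformers/all-MiniLM-L6-v2
#     batch: 8192
#     encodebatch: 128
#   scoring:
#     method: bm25
#     terms: true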


class Score(Index):
    """
    BM25 index using txtai.
    """

    def index(self):
        # Read scoring configuration
        config = self.readconfig("scoring", {"method": "bm25", "terms": True})

        scoring = ScoringFactory.create(config)
        if os.path.exists(self.output) and not self.refresh:
            scoring.load(self.output)
        else:
            scoring.index(self.rows())
            scoring.save(self.output)

        return scoring
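
# Note on the scoring configuration above: "terms": True tells txtai to store a term index
# alongside the BM25 statistics so the scoring instance can be searched directly, which is
# what the base Index.search() relies on.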


class Embed(Index):
    """
    Embeddings index using txtai.
    """

    def index(self):
        if os.path.exists(self.output) and not self.refresh:
            embeddings = Embeddings()
            embeddings.load(self.output)
        else:
            # Read embeddings configuration
            config = self.readconfig("embeddings", {"batch": 8192, "encodebatch": 128, "faiss": {"quantize": True, "sample": 0.05}})

            embeddings = Embeddings(config)
            embeddings.index(self.rows())
            embeddings.save(self.output)

        return embeddings


class Hybrid(Index):
    """
    Hybrid embeddings + BM25 index using txtai.
    """

    def index(self):
        if os.path.exists(self.output) and not self.refresh:
            embeddings = Embeddings()
            embeddings.load(self.output)
        else:
            # Read hybrid configuration
            config = self.readconfig(
                "hybrid",
                {
                    "faiss": {"quantize": True, "sample": 0.05},
                    "scoring": {"method": "bm25", "terms": True, "normalize": True},
                },
            )

            embeddings = Embeddings(config)
            embeddings.index(self.rows())
            embeddings.save(self.output)

        return embeddings
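
# Note on the hybrid configuration above: "normalize": True rescales the BM25 scores into a
# 0-1 range so they can be fused with the vector similarity scores when ranking results.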


class RAG(Embed):
    """
    Retrieval augmented generation (RAG) using txtai.
    """

    def __init__(self, path, config, output, refresh):
        super().__init__(path, config, output, refresh)

        # Read LLM and extractor (RAG pipeline) configuration
        llm = self.readconfig("llm", {})
        extractor = self.readconfig("extractor", {})

        # RAG pipeline configured to return the best matching reference for each answer
        self.extractor = Extractor(self.backend, LLM(**llm), output="reference", **extractor)

    def search(self, queries, limit):
        # Use the top limit results as context, return the cited reference as the single hit
        self.extractor.context = limit
        return [[(x["reference"], 1)] for x in self.extractor(queries, maxlength=4096)]
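
# Note on RAG evaluation above: each query yields a single pseudo-result, the reference id
# cited by the RAG pipeline, with a fixed score of 1. The standard retrieval metrics then
# measure how often that cited reference is a relevant document.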


class RankBM25(Index):
    """
    BM25 index using rank-bm25.
    """

    def search(self, queries, limit):
        ids, backend = self.backend

        tokenizer, results = Tokenizer(), []
        for query in queries:
            # Score every document for this query, then keep the top limit ids
            scores = backend.get_scores(tokenizer(query))
            topn = np.argsort(scores)[::-1][:limit]
            results.append([(ids[x], scores[x]) for x in topn])

        return results

    def index(self):
        if os.path.exists(self.output) and not self.refresh:
            with open(self.output, "rb") as f:
                ids, model = pickle.load(f)
        else:
            # Tokenize each document
            tokenizer, data = Tokenizer(), []
            for uid, text, _ in self.rows():
                data.append((uid, tokenizer(text)))

            # Build an in-memory BM25 model over the tokenized corpus
            ids = [uid for uid, _ in data]
            model = BM25Okapi([text for _, text in data])

            # Persist ids and model for later runs
            with open(self.output, "wb") as f:
                pickle.dump((ids, model), f)

        return ids, model
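
# Note on the rank-bm25 baseline above: BM25Okapi keeps the entire tokenized corpus in memory
# and get_scores() scores every document for each query, so this method trades memory usage and
# query latency for implementation simplicity compared with the other BM25 backends.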


class SQLiteFTS(Index):
    """
    BM25 index using SQLite's FTS extension.
    """

    def search(self, queries, limit):
        tokenizer, results = Tokenizer(), []
        for query in queries:
            # Tokenize query and join terms with OR
            query = tokenizer(query)
            query = " OR ".join([f'"{q}"' for q in query])

            # FTS5 bm25() returns smaller (more negative) values for better matches, so flip the sign
            self.backend.execute(
                f"SELECT id, bm25(textindex) * -1 score FROM textindex WHERE text MATCH ? ORDER BY bm25(textindex) LIMIT {limit}", [query]
            )
            results.append(list(self.backend))

        return results

    def index(self):
        if os.path.exists(self.output) and not self.refresh:
            connection = sqlite3.connect(self.output)
        else:
            # Delete existing database
            if os.path.exists(self.output):
                os.remove(self.output)

            connection = sqlite3.connect(self.output)

            # Tokenize each document into a space separated string
            tokenizer, data = Tokenizer(), []
            for uid, text, _ in self.rows():
                data.append((uid, " ".join(tokenizer(text))))

            # Create FTS5 table and load data
            connection.execute("CREATE VIRTUAL TABLE textindex using fts5(id, text)")
            connection.executemany("INSERT INTO textindex VALUES (?, ?)", data)
            connection.commit()

        return connection.cursor()


class Elastic(Index):
    """
    BM25 index using Elasticsearch.
    """

    def search(self, queries, limit):
        # Build a multi-search request with a header and body entry per query
        results, request = [], []
        for query in queries:
            req_head = {"index": "textindex", "search_type": "dfs_query_then_fetch"}
            req_body = {
                "query": {"multi_match": {"query": query, "type": "best_fields", "fields": ["text"], "tie_breaker": 0.5}},
                "size": limit,
            }

            request.extend([req_head, req_body])

        response = self.backend.msearch(body=request, request_timeout=600)

        for resp in response["responses"]:
            result = resp["hits"]["hits"]
            results.append([(r["_id"], r["_score"]) for r in result])

        return results
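
    # Note on the msearch call above: the Elasticsearch multi-search API expects alternating
    # header and body entries, hence one req_head/req_body pair per query. The
    # dfs_query_then_fetch search type gathers term statistics across all shards first, which
    # gives more consistent BM25 scores than the default per-shard statistics.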

    def index(self):
        es = Elasticsearch("http://localhost:9200")

        if self.refresh:
            # Recreate the index and bulk load the dataset
            es.indices.delete(index="textindex")
            bulk(es, ({"_index": "textindex", "_id": uid, "text": text} for uid, text, _ in self.rows()))
            es.indices.refresh(index="textindex")

        return es


def relevance(path):
    """
    Loads relevance data for evaluation.

    Args:
        path: path to dataset

    Returns:
        relevance data
    """

    rel = {}
    with open(f"{path}/qrels/test.tsv", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_MINIMAL)

        # Skip header row
        next(reader)

        for row in reader:
            queryid, corpusid, score = row[0], row[1], int(row[2])
            if queryid not in rel:
                rel[queryid] = {corpusid: score}
            else:
                rel[queryid][corpusid] = score

    return rel
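
# The qrels/test.tsv file read above follows the standard BEIR format: a header row followed by
# one (query-id, corpus-id, relevance score) triple per line. For example (ids are illustrative):
#
#   query-id    corpus-id    score
#   q1          doc42        1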


def create(method, path, config, output, refresh):
    """
    Creates an index instance for the input method.

    Args:
        method: indexing method
        path: path to dataset
        config: path to config file
        output: path to store index
        refresh: overwrites existing index if True, otherwise existing index is loaded

    Returns:
        index
    """

    if method == "es":
        return Elastic(path, config, output, refresh)
    if method == "hybrid":
        return Hybrid(path, config, output, refresh)
    if method == "rag":
        return RAG(path, config, output, refresh)
    if method in ("bm25", "scoring"):
        return Score(path, config, output, refresh)
    if method == "sqlite":
        return SQLiteFTS(path, config, output, refresh)
    if method == "rank":
        return RankBM25(path, config, output, refresh)

    # Default to embeddings
    return Embed(path, config, output, refresh)


def compute(results):
    """
    Computes metrics using the results from an evaluation run.

    Args:
        results: evaluation results

    Returns:
        metrics averaged over all queries
    """

    # Collect each metric value per query
    metrics = {}
    for r in results:
        for metric in results[r]:
            if metric not in metrics:
                metrics[metric] = []

            metrics[metric].append(results[r][metric])

    # Average each metric across queries
    return {metric: round(np.mean(values), 5) for metric, values in metrics.items()}


def evaluate(methods, path, args):
    """
    Runs an evaluation of each indexing method against the dataset at path.

    Args:
        methods: list of indexing methods to test
        path: path to dataset
        args: command line arguments

    Returns:
        {calculated performance metrics}
    """

    print(f"------ {os.path.basename(path)} ------")

    # Relevance evaluator scoring NDCG, MAP, recall and precision at topk
    topk = args.topk
    evaluator = RelevanceEvaluator(relevance(path), {f"ndcg_cut.{topk}", f"map_cut.{topk}", f"recall.{topk}", f"P.{topk}"})

    performance = {}
    for method in methods:
        stats = {}
        performance[method] = stats

        # Build or load the index, tracking time, memory and disk usage
        start = time.time()
        output = args.output if args.output else f"{path}/{method}"
        index = create(method, path, args.config, output, args.refresh)

        stats["index"] = round(time.time() - start, 2)
        stats["memory"] = int(psutil.Process().memory_info().rss / (1024 * 1024))
        stats["disk"] = int(sum(d.stat().st_size for d in os.scandir(output) if d.is_file()) / 1024) if os.path.isdir(output) else 0

        print("INDEX TIME =", time.time() - start)
        print(f"MEMORY USAGE = {stats['memory']} MB")
        print(f"DISK USAGE = {stats['disk']} KB")

        # Run dataset queries
        start = time.time()
        results = index(topk)

        stats["search"] = round(time.time() - start, 2)
        print("SEARCH TIME =", time.time() - start)

        # Calculate and store metrics
        metrics = compute(evaluator.evaluate(results))
        for stat in [f"ndcg_cut_{topk}", f"map_cut_{topk}", f"recall_{topk}", f"P_{topk}"]:
            stats[stat] = metrics[stat]

        print(f"------ {method} ------")
        print(f"NDCG@{topk} =", metrics[f"ndcg_cut_{topk}"])
        print(f"MAP@{topk} =", metrics[f"map_cut_{topk}"])
        print(f"Recall@{topk} =", metrics[f"recall_{topk}"])
        print(f"P@{topk} =", metrics[f"P_{topk}"])

    return performance


def benchmarks(args):
    """
    Main benchmark execution method.

    Args:
        args: command line arguments
    """

    # Directory with BEIR datasets
    directory = args.directory if args.directory else "beir"

    # Run user-specified sources and methods, when provided, and append to any existing results
    if args.sources and args.methods:
        sources, methods = args.sources.split(","), args.methods.split(",")
        mode = "a"
    else:
        # Otherwise run the default methods against every dataset found in the directory
        # and start a new results file
        sources = sorted(os.listdir(directory))
        methods = ["bm25", "embed", "es", "hybrid", "rank", "sqlite"]
        mode = "w"

    with open("benchmarks.json", mode, encoding="utf-8") as f:
        for source in sources:
            # Run evaluations for this source
            results = evaluate(methods, f"{directory}/{source}", args)

            # Add run metadata and save results as JSON lines
            for method, stats in results.items():
                stats["source"] = source
                stats["method"] = method
                stats["name"] = args.name if args.name else method

                json.dump(stats, f)
                f.write("\n")


if __name__ == "__main__":
    # Command line parser
    parser = argparse.ArgumentParser(description="Benchmarks")
    parser.add_argument("-c", "--config", help="path to config file", metavar="CONFIG")
    parser.add_argument("-d", "--directory", help="root directory path with datasets", metavar="DIRECTORY")
    parser.add_argument("-m", "--methods", help="comma separated list of methods", metavar="METHODS")
    parser.add_argument("-n", "--name", help="name to assign to this run, defaults to method name", metavar="NAME")
    parser.add_argument("-o", "--output", help="index output directory path", metavar="OUTPUT")
    parser.add_argument(
        "-r",
        "--refresh",
        help="refreshes index if set, otherwise uses existing index if available",
        action="store_true",
    )
    parser.add_argument("-s", "--sources", help="comma separated list of data sources", metavar="SOURCES")
    parser.add_argument("-t", "--topk", help="top k results to use for the evaluation", metavar="TOPK", type=int, default=10)

    # Run benchmarks
    benchmarks(parser.parse_args())