3
from pathlib import Path
5
from fastrag.stores import PLAIDDocumentStore
7
# Logger named after this module; it is "__main__" when the file runs as a script.
logger: logging.Logger = logging.getLogger(name=__name__)
9
if __name__ == "__main__":
10
parser = argparse.ArgumentParser("Create an index using PLAID engine as a backend")
11
parser.add_argument("--checkpoint", type=Path, required=True)
12
parser.add_argument("--collection", type=Path, required=True)
13
parser.add_argument("--index-save-path", type=Path, required=True)
14
parser.add_argument("--gpus", type=int, default=0)
15
parser.add_argument("--ranks", type=int, default=1)
16
parser.add_argument("--doc-max-length", type=int, default=120)
17
parser.add_argument("--query-max-length", type=int, default=60)
18
parser.add_argument("--kmeans-iterations", type=int, default=4)
19
parser.add_argument("--name", type=str, default="plaid_index")
20
parser.add_argument("--nbits", type=int, default=2)
22
args = parser.parse_args()
25
args.ranks = args.gpus
31
store = PLAIDDocumentStore(
32
index_path=f"{args.index_save_path}",
33
checkpoint_path=f"{args.checkpoint}",
34
collection_path=f"{args.collection}",
39
doc_maxlen=args.doc_max_length,
40
query_maxlen=args.query_max_length,
41
kmeans_niters=args.kmeans_iterations,