import argparse
import logging
from pathlib import Path

from fastrag.utils import init_cls, init_haystack_cls, load_yaml

logger = logging.getLogger(__name__)

# Batch size handed to store.write_documents() when --batch_size is omitted.
DEFAULT_WRITE_BATCH_SIZE = 100


def build_parser():
    """Return the command-line parser for this script."""
    # The text must go in ``description``: the first positional argument of
    # ArgumentParser is ``prog``, which would replace the program name shown
    # in the usage line.
    parser = argparse.ArgumentParser(
        description="Embed data and save to pickled file."
    )
    parser.add_argument(
        "--data",
        type=Path,
        required=True,
        help="YAML config of the dataset (must contain a 'type' key).",
    )
    parser.add_argument(
        "--embedder",
        type=Path,
        required=True,
        help="YAML config of the embedder (must contain a 'type' key).",
    )
    parser.add_argument(
        "--store",
        type=Path,
        required=True,
        help="YAML config of the document store (must contain a 'type' key).",
    )
    parser.add_argument(
        "--batch_num",
        type=int,
        required=False,
        default=0,
        help=(
            "Index of the first data chunk to process; earlier chunks are "
            "skipped (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        required=False,
        default=DEFAULT_WRITE_BATCH_SIZE,
        help=(
            "batch_size passed to the store's write_documents "
            "(default: %(default)s)."
        ),
    )
    return parser


def parse_args(argv=None):
    """Parse and validate command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits (status 2, via ``parser.error``) when --batch_num is negative or
    --batch_size is not positive.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    # A negative start index would be passed straight to data.process().
    if args.batch_num < 0:
        parser.error("--batch_num must be >= 0")
    if args.batch_size <= 0:
        parser.error("--batch_size must be > 0")
    return args


def _init_from_yaml(config_path, factory, what):
    """Load a YAML config, pop its 'type' key and build it with *factory*.

    Args:
        config_path: Path of the YAML config file.
        factory: ``init_cls`` or ``init_haystack_cls``; called as
            ``factory(type_name, remaining_params)``.
        what: Human-readable component name used in the error message.

    Raises:
        ValueError: If the config has no 'type' key.
    """
    params = load_yaml(config_path)
    try:
        cls_name = params.pop("type")
    except KeyError:
        raise ValueError(
            f"{what} config {config_path} has no 'type' key"
        ) from None
    return factory(cls_name, params)


def main(argv=None):
    """Embed every data chunk from --batch_num onward and write it to the store.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    # Without a handler, the logger.info() progress messages below are dropped
    # (the root logger defaults to WARNING). This is a no-op if logging is
    # already configured.
    logging.basicConfig(level=logging.INFO)
    args = parse_args(argv)

    store = _init_from_yaml(args.store, init_haystack_cls, "Store")
    logger.info("Loaded store backend")

    data = _init_from_yaml(args.data, init_cls, "Data")
    logger.info("Done loading dataset")

    logger.info("Loading Embedder")
    emb = _init_from_yaml(args.embedder, init_haystack_cls, "Embedder")

    logger.info("Creating Embeddings...")
    # Start directly at --batch_num instead of iterating and skipping.
    for batch_i in range(args.batch_num, data.chunks):
        logger.info("Processing chunk %d/%d", batch_i + 1, data.chunks)
        docs = data.process(batch_i)
        embeddings = emb.embed_documents(docs)
        # zip() would silently drop documents if the counts disagree.
        if len(embeddings) != len(docs):
            raise RuntimeError(
                f"Embedder returned {len(embeddings)} embeddings for "
                f"{len(docs)} documents in chunk {batch_i}"
            )
        batch = []
        for doc, embedding in zip(docs, embeddings):
            # NOTE(review): assumes Haystack-style documents whose to_dict()
            # serializes the `embedding` attribute; confirm.
            doc.embedding = embedding
            batch.append(doc.to_dict())
        store.write_documents(batch, batch_size=args.batch_size)


if __name__ == "__main__":
    main()