summaryrefslogtreecommitdiff
path: root/run.py
diff options
context:
space:
mode:
Diffstat (limited to 'run.py')
-rw-r--r--run.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/run.py b/run.py
index 0d5d017..4ea5b3b 100644
--- a/run.py
+++ b/run.py
@@ -25,6 +25,7 @@ def parse_arguments():
parser.add_argument("--iteration_threshold", type=float, default=0.4, help="The threshold for iteration")
parser.add_argument("--passage_ratio", type=float, default=2, help="The ratio for passage")
parser.add_argument("--top_k_sentence", type=int, default=3, help="The top k sentence to use")
+ parser.add_argument("--use_vectorized_retrieval", action="store_true", help="Use vectorized matrix-based retrieval instead of BFS iteration")
return parser.parse_args()
@@ -59,7 +60,8 @@ def main():
max_iterations=args.max_iterations,
iteration_threshold=args.iteration_threshold,
passage_ratio=args.passage_ratio,
- top_k_sentence=args.top_k_sentence
+ top_k_sentence=args.top_k_sentence,
+ use_vectorized_retrieval=args.use_vectorized_retrieval
)
rag_model = LinearRAG(global_config=config)
rag_model.index(passages)