 run.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)
diff --git a/run.py b/run.py
index b1f21f2..0d5d017 100644
--- a/run.py
+++ b/run.py
@@ -9,16 +9,22 @@ import warnings
 from src.evaluate import Evaluator
 from src.utils import LLM_Model
 from src.utils import setup_logging
+from datetime import datetime


+os.environ["CUDA_VISIBLE_DEVICES"] = "4"
 warnings.filterwarnings('ignore')


 def parse_arguments():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--spacy_model", type=str, default="en_core_sci_scibert", help="The spacy model to use")
+    parser.add_argument("--spacy_model", type=str, default="en_core_web_trf", help="The spacy model to use")
     parser.add_argument("--embedding_model", type=str, default="model/all-mpnet-base-v2", help="The path of the embedding model to use")
-    parser.add_argument("--dataset_name", type=str, default="medical", help="The dataset to use")
+    parser.add_argument("--dataset_name", type=str, default="novel", help="The dataset to use")
     parser.add_argument("--llm_model", type=str, default="gpt-4o-mini", help="The LLM model to use")
     parser.add_argument("--max_workers", type=int, default=16, help="The max number of workers to use")
+    parser.add_argument("--max_iterations", type=int, default=3, help="The max number of iterations to use")
+    parser.add_argument("--iteration_threshold", type=float, default=0.4, help="The threshold for iteration")
+    parser.add_argument("--passage_ratio", type=float, default=2.0, help="The passage ratio to use")
+    parser.add_argument("--top_k_sentence", type=int, default=3, help="The top k sentences to use")
     return parser.parse_args()
@@ -37,25 +43,31 @@ def load_embedding_model(embedding_model):
     return embedding_model

 def main():
+    time = datetime.now()
+    time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
     args = parse_arguments()
     embedding_model = load_embedding_model(args.embedding_model)
     questions,passages = load_dataset(args.dataset_name)
-    setup_logging(f"results/{args.dataset_name}/log.txt")
+    setup_logging(f"results/{args.dataset_name}/{time_str}/log.txt")
     llm_model = LLM_Model(args.llm_model)
     config = LinearRAGConfig(
         dataset_name=args.dataset_name,
         embedding_model=embedding_model,
         spacy_model=args.spacy_model,
         max_workers=args.max_workers,
-        llm_model=llm_model
+        llm_model=llm_model,
+        max_iterations=args.max_iterations,
+        iteration_threshold=args.iteration_threshold,
+        passage_ratio=args.passage_ratio,
+        top_k_sentence=args.top_k_sentence
     )
     rag_model = LinearRAG(global_config=config)
     rag_model.index(passages)
     questions = rag_model.qa(questions)
-    os.makedirs(f"results/{args.dataset_name}", exist_ok=True)
-    with open(f"results/{args.dataset_name}/predictions.json", "w", encoding="utf-8") as f:
+    os.makedirs(f"results/{args.dataset_name}/{time_str}", exist_ok=True)
+    with open(f"results/{args.dataset_name}/{time_str}/predictions.json", "w", encoding="utf-8") as f:
         json.dump(questions, f, ensure_ascii=False, indent=4)
-    evaluator = Evaluator(llm_model=llm_model, predictions_path=f"results/{args.dataset_name}/predictions.json")
+    evaluator = Evaluator(llm_model=llm_model, predictions_path=f"results/{args.dataset_name}/{time_str}/predictions.json")
     evaluator.evaluate(max_workers=args.max_workers)

 if __name__ == "__main__":
     main()
\ No newline at end of file
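
A minimal usage sketch, assuming this run.py is the repository's entry point (the flag values shown are simply the new defaults, and the timestamp in the output path is illustrative):

    python run.py --dataset_name novel --max_iterations 3 --iteration_threshold 0.4 --passage_ratio 2.0 --top_k_sentence 3

Each run then writes log.txt and predictions.json under results/<dataset_name>/<timestamp>/, so repeated runs no longer overwrite earlier results. Note that the script pins CUDA_VISIBLE_DEVICES to GPU 4 at import time, before argument parsing, so device selection is not adjustable from the command line.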