diff options
| -rw-r--r-- | run.py | 26 |
1 file changed, 19 insertions, 7 deletions
@@ -9,16 +9,22 @@ import warnings from src.evaluate import Evaluator from src.utils import LLM_Model from src.utils import setup_logging +from datetime import datetime +os.environ["CUDA_VISIBLE_DEVICES"] = "4" warnings.filterwarnings('ignore') def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument("--spacy_model", type=str, default="en_core_sci_scibert", help="The spacy model to use") + parser.add_argument("--spacy_model", type=str, default="en_core_web_trf", help="The spacy model to use") parser.add_argument("--embedding_model", type=str, default="model/all-mpnet-base-v2", help="The path of embedding model to use") - parser.add_argument("--dataset_name", type=str, default="medical", help="The dataset to use") + parser.add_argument("--dataset_name", type=str, default="novel", help="The dataset to use") parser.add_argument("--llm_model", type=str, default="gpt-4o-mini", help="The LLM model to use") parser.add_argument("--max_workers", type=int, default=16, help="The max number of workers to use") + parser.add_argument("--max_iterations", type=int, default=3, help="The max number of iterations to use") + parser.add_argument("--iteration_threshold", type=float, default=0.4, help="The threshold for iteration") + parser.add_argument("--passage_ratio", type=float, default=2, help="The ratio for passage") + parser.add_argument("--top_k_sentence", type=int, default=3, help="The top k sentence to use") return parser.parse_args() @@ -37,25 +43,31 @@ def load_embedding_model(embedding_model): return embedding_model def main(): + time = datetime.now() + time_str = time.strftime("%Y-%m-%d_%H-%M-%S") args = parse_arguments() embedding_model = load_embedding_model(args.embedding_model) questions,passages = load_dataset(args.dataset_name) - setup_logging(f"results/{args.dataset_name}/log.txt") + setup_logging(f"results/{args.dataset_name}/{time_str}/log.txt") llm_model = LLM_Model(args.llm_model) config = LinearRAGConfig( dataset_name=args.dataset_name, 
embedding_model=embedding_model, spacy_model=args.spacy_model, max_workers=args.max_workers, - llm_model=llm_model + llm_model=llm_model, + max_iterations=args.max_iterations, + iteration_threshold=args.iteration_threshold, + passage_ratio=args.passage_ratio, + top_k_sentence=args.top_k_sentence ) rag_model = LinearRAG(global_config=config) rag_model.index(passages) questions = rag_model.qa(questions) - os.makedirs(f"results/{args.dataset_name}", exist_ok=True) - with open(f"results/{args.dataset_name}/predictions.json", "w", encoding="utf-8") as f: + os.makedirs(f"results/{args.dataset_name}/{time_str}", exist_ok=True) + with open(f"results/{args.dataset_name}/{time_str}/predictions.json", "w", encoding="utf-8") as f: json.dump(questions, f, ensure_ascii=False, indent=4) - evaluator = Evaluator(llm_model=llm_model, predictions_path=f"results/{args.dataset_name}/predictions.json") + evaluator = Evaluator(llm_model=llm_model, predictions_path=f"results/{args.dataset_name}/{time_str}/predictions.json") evaluator.evaluate(max_workers=args.max_workers) if __name__ == "__main__": main()
\ No newline at end of file
