summaryrefslogtreecommitdiff
path: root/kg_rag/run_setup.py
blob: 04c856cef2be28da8222657986985dfda78b7e63 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
from kg_rag.utility import config_data

def download_llama(method=None):
    """Download the Llama model configured in config.yaml into the LLM cache dir.

    Reads LLAMA_MODEL_NAME, LLAMA_MODEL_BRANCH and LLM_CACHE_DIR from
    ``config_data`` and delegates the download to
    ``kg_rag.utility.llama_model``. Failures are reported on stdout instead
    of being raised, so an interactive caller can retry with another method.

    Args:
        method: Loading strategy forwarded to ``llama_model`` as ``method=``
            (this script uses 'method-1' and 'method-2'). When None, no
            ``method`` argument is passed at all. This assumes ``llama_model``
            has its own default for it (TODO confirm). If it does not, the
            resulting error is reported like any other download failure.
    """
    from kg_rag.utility import llama_model
    # Only forward `method` when the caller chose one.
    extra_kwargs = {} if method is None else {"method": method}
    try:
        llama_model(
            config_data["LLAMA_MODEL_NAME"],
            config_data["LLAMA_MODEL_BRANCH"],
            config_data["LLM_CACHE_DIR"],
            **extra_kwargs,
        )
    except Exception as err:
        # Best-effort: report and return so the caller can try another method.
        # Catching Exception (not a bare except) lets Ctrl-C / SystemExit abort.
        print("Model is not downloaded! Make sure the above mentioned conditions are satisfied")
        print("Reason:", repr(err))
    else:
        print("Model is successfully downloaded to the provided cache directory!")
        

print("")
print("Starting to set up KG-RAG ...")
print("")

# The interactive "did you update config.yaml?" confirmation is disabled, so
# setup always proceeds. To restore it, ask:
#   input("Did you update the config.yaml file with all necessary configurations (such as GPT .env path, vectorDB file paths, other file paths)? Enter Y or N: ")
# and on any answer other than "Y" print
#   "As the first step, update config.yaml file and then run this python script again."
# and stop before the steps below.

# Step 1: make sure the disease vector database exists, building it if missing.
print("Checking disease vectorDB ...")
print("The current VECTOR_DB_PATH is ", config_data["VECTOR_DB_PATH"])
try:
    if os.path.exists(config_data["VECTOR_DB_PATH"]):
        print("vectorDB already exists!")
    else:
        print("Creating vectorDB ...")
        # Imported here so the builder is only loaded when the DB must be created.
        from kg_rag.vectorDB.create_vectordb import create_vectordb
        create_vectordb()
        print("Congratulations! The disease database is completed.")
except Exception as err:
    # Catch Exception rather than using a bare except so Ctrl-C still aborts.
    # Also show the underlying error: the import or the build itself can fail
    # for reasons unrelated to the configured path.
    print("Double check the path that was given in VECTOR_DB_PATH of config.yaml file.")
    print("Error:", repr(err))

# Step 2 (disabled): interactive Llama download flow, kept below as an inert
# string literal for reference only. Before re-enabling it, check that every
# download_llama(...) call in it matches that function's current signature.
'''
    print("")
    user_input_1 = input("Do you want to install Llama model? Enter Y or N: ")
    if user_input_1 == "Y":
        user_input_2 = input("Did you update the config.yaml file with proper configuration for downloading Llama model? Enter Y or N: ")
        if user_input_2 == "Y":
            user_input_3 = input("Are you using official Llama model from Meta? Enter Y or N: ")
            if user_input_3 == "Y":
                user_input_4 = input("Did you get access to use the model? Enter Y or N: ")
                if user_input_4 == "Y":
                    download_llama()
                    print("Congratulations! Setup is completed.")
                else:
                    print("Aborting!")
            else:
                download_llama(method='method-1')
                user_input_5 = input("Did you get a message like 'Model is not downloaded!'?  Enter Y or N: ")
                if user_input_5 == "N":                
                    print("Congratulations! Setup is completed.")
                else:
                    download_llama(method='method-2')
                    user_input_6 = input("Did you get a message like 'Model is not downloaded!'?  Enter Y or N: ")
                    if user_input_6 == "N":                        
                        print("""
                        IMPORTANT : 
                        Llama model was downloaded using 'LlamaTokenizer' instead of 'AutoTokenizer' method. 
                        So, when you run text generation script, please provide an extra command line argument '-m method-2'.
                        For example:
                            python -m kg_rag.rag_based_generation.Llama.text_generation -m method-2
                        """)
                        print("Congratulations! Setup is completed.")
                    else:
                        print("We have now tried two methods to download Llama. If they both do not work, then please check the Llama configuration requirement in the huggingface model card page. Aborting!")
        else:
            print("Aborting!")
    else:
        print("No problem. Llama will get installed on-the-fly when you run the model for the first time.")
        print("Congratulations! Setup is completed.")
'''