# Source: root/scripts/diagnose_oom.py
# Blob: 22de3f9651d875667f3a097240686aa6ed77dfb5
import torch
from transformers import AutoModel, AutoTokenizer
import os
import psutil
import time
import sys
import gc

def log_mem(msg):
    """Print a one-line memory snapshot tagged with ``msg`` and flush stdout.

    The line always reports the current process's resident set size as
    returned by ``psutil``. When ``torch.cuda.is_available()`` is true it
    also reports ``torch.cuda.memory_allocated()`` and
    ``torch.cuda.memory_reserved()``; otherwise the GPU field is ``N/A``.
    All values are divided by 1024**3 before printing (labelled "GB").

    Args:
        msg: Short label placed in square brackets at the start of the line.
    """
    bytes_per_unit = 1024 ** 3
    rss = psutil.Process().memory_info().rss / bytes_per_unit

    if not torch.cuda.is_available():
        report = f"[{msg}] RAM: {rss:.2f}GB | GPU: N/A"
    else:
        allocated = torch.cuda.memory_allocated() / bytes_per_unit
        reserved = torch.cuda.memory_reserved() / bytes_per_unit
        report = f"[{msg}] RAM: {rss:.2f}GB | GPU Alloc: {allocated:.2f}GB | GPU Res: {reserved:.2f}GB"

    print(report)
    # Flush right away so the snapshot is visible even if the process is
    # killed shortly afterwards.
    sys.stdout.flush()

def main():
    """Run the memory diagnostic for loading the local embedding model.

    Steps, with a ``log_mem`` snapshot printed around each one:

    1. Best-effort: load ``configs/local_models.yaml`` and print the
       ``models.embedding.qwen3`` entry. Any failure here (including a
       missing PyYAML install) is reported and the diagnostic continues.
    2. Run ``gc.collect()`` and ``torch.cuda.empty_cache()``.
    3. Load the tokenizer from ``models/qwen3-embedding-8b``.
    4. Load the model onto ``cuda:0`` in bfloat16. If this raises, the
       error and a memory snapshot are printed and the function returns.
    5. Run one forward pass on a tiny input under ``torch.no_grad()`` and
       print the shape of ``last_hidden_state``.

    Returns:
        None. All results are printed to stdout.
    """
    print("--- Diagnostic Script ---")
    log_mem("Start")

    model_path = "models/qwen3-embedding-8b"
    print(f"Model path: {model_path}")

    # Check config. This step is informational only, so every failure is
    # reported and swallowed rather than aborting the diagnostic.
    try:
        # Imported inside the try so that a missing PyYAML is reported like
        # any other config problem instead of crashing before the model load.
        import yaml

        # Explicit encoding so decoding does not depend on the locale default.
        with open("configs/local_models.yaml", "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
        print("Config loaded from local_models.yaml:")
        print(cfg['models']['embedding']['qwen3'])
    except Exception as e:
        print(f"Could not load config: {e}")

    # Explicit garbage collection so the "Pre-Load" snapshot is taken from as
    # clean a state as possible.
    gc.collect()
    torch.cuda.empty_cache()
    log_mem("Pre-Load")

    print("Loading Tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=False)
    log_mem("Tokenizer Loaded")

    print("Loading Model (trust_remote_code=False)...")
    try:
        # low_cpu_mem_usage=True is passed explicitly.
        # NOTE(review): passing a device_map is believed to enable it already;
        # confirm against the installed transformers version.
        model = AutoModel.from_pretrained(
            model_path,
            device_map="cuda:0",
            torch_dtype=torch.bfloat16,
            trust_remote_code=False,
            low_cpu_mem_usage=True,
        )
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Model load failed: {e}")
        # Record memory at the point of failure; this is the most useful
        # data point when the load dies from memory exhaustion.
        log_mem("Model Load Failed")
        return

    log_mem("Model Loaded")

    print("Testing forward pass with small input...")
    input_text = "Hello world"
    inputs = tokenizer(input_text, return_tensors="pt").to("cuda:0")

    try:
        # no_grad: inference only, so autograd state is not recorded.
        with torch.no_grad():
            outputs = model(**inputs)
        print("Forward pass success.")
        print(f"Output shape: {outputs.last_hidden_state.shape}")
    except Exception as e:
        print(f"Forward pass failed: {e}")

    log_mem("End")

# Run the diagnostic only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()