blob: 22de3f9651d875667f3a097240686aa6ed77dfb5 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
|
import torch
from transformers import AutoModel, AutoTokenizer
import os
import psutil
import time
import sys
import gc
def log_mem(msg):
    """Print current process RSS and, when CUDA is present, GPU memory stats.

    Output is tagged with *msg* and flushed immediately so the log survives
    a crash mid-run.
    """
    gib = 1024 ** 3
    ram_gb = psutil.Process().memory_info().rss / gib
    if not torch.cuda.is_available():
        print(f"[{msg}] RAM: {ram_gb:.2f}GB | GPU: N/A")
    else:
        alloc_gb = torch.cuda.memory_allocated() / gib
        reserved_gb = torch.cuda.memory_reserved() / gib
        print(f"[{msg}] RAM: {ram_gb:.2f}GB | GPU Alloc: {alloc_gb:.2f}GB | GPU Res: {reserved_gb:.2f}GB")
    sys.stdout.flush()
def main():
    """Load the Qwen3 embedding model and run one tiny forward pass,
    logging RAM/GPU usage at each stage to localize memory problems.

    Best-effort throughout: config-load and forward-pass failures are
    reported and the diagnostic continues; a model-load failure aborts.
    """
    print("--- Diagnostic Script ---")
    log_mem("Start")
    model_path = "models/qwen3-embedding-8b"
    print(f"Model path: {model_path}")
    # Pick the device once; fall back to CPU so the diagnostic still runs
    # (and still logs memory) on hosts without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    # Check config. The yaml import lives inside the try so a missing
    # PyYAML is reported via the handler instead of aborting the script.
    try:
        import yaml
        with open("configs/local_models.yaml", "r") as f:
            cfg = yaml.safe_load(f)
        print("Config loaded from local_models.yaml:")
        print(cfg['models']['embedding']['qwen3'])
    except Exception as e:
        print(f"Could not load config: {e}")
    # Explicit garbage collection for a clean memory baseline before loading.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    log_mem("Pre-Load")
    print("Loading Tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=False)
    log_mem("Tokenizer Loaded")
    print("Loading Model (trust_remote_code=False)...")
    try:
        # low_cpu_mem_usage avoids materializing a full extra copy of the
        # weights in RAM during load (important for an 8B model).
        model = AutoModel.from_pretrained(
            model_path,
            device_map=device,
            torch_dtype=torch.bfloat16,
            trust_remote_code=False,
            low_cpu_mem_usage=True
        )
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Model load failed: {e}")
        return
    log_mem("Model Loaded")
    print("Testing forward pass with small input...")
    input_text = "Hello world"
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    try:
        with torch.no_grad():
            outputs = model(**inputs)
        print("Forward pass success.")
        print(f"Output shape: {outputs.last_hidden_state.shape}")
    except Exception as e:
        print(f"Forward pass failed: {e}")
    log_mem("End")
# Run the diagnostic only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|