diff options
| author | LuyaoZhuang <zhuangluyao523@gmail.com> | 2025-10-26 05:09:57 -0400 |
|---|---|---|
| committer | LuyaoZhuang <zhuangluyao523@gmail.com> | 2025-10-26 05:09:57 -0400 |
| commit | ccff87c15263d1d63235643d54322b991366952e (patch) | |
| tree | 71051f88bc8df1b3d40a65df6ca72e4586a1d1aa /src/utils.py | |
commit
Diffstat (limited to 'src/utils.py')
| -rw-r--r-- | src/utils.py | 77 |
1 file changed, 77 insertions, 0 deletions
"""Shared utilities: content hashing, a thin OpenAI chat wrapper,
SQuAD-style answer normalization, logging setup, and min-max scaling."""

from hashlib import md5
from dataclasses import dataclass, field
from typing import List, Dict
from collections import defaultdict
import multiprocessing as mp
import re
import string
import logging
import numpy as np
import os


def compute_mdhash_id(content: str, prefix: str = "") -> str:
    """Return a stable MD5-based id for *content*, with an optional *prefix*.

    MD5 is used as a content fingerprint here, not for security.
    """
    return prefix + md5(content.encode()).hexdigest()


class LLM_Model:
    """Thin wrapper around the OpenAI chat-completions endpoint.

    Credentials and endpoint are read from the OPENAI_API_KEY and
    OPENAI_BASE_URL environment variables.
    """

    def __init__(self, llm_model):
        """Build the OpenAI client and a fixed decoding config for *llm_model*."""
        # Imported lazily so the rest of this module (normalize_answer,
        # min_max_normalize, ...) stays importable when the httpx/openai
        # client libraries are not installed.
        import httpx
        from openai import OpenAI

        # trust_env=False: ignore HTTP(S)_PROXY env vars; 60 s request timeout.
        http_client = httpx.Client(timeout=60.0, trust_env=False)
        self.openai_client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_BASE_URL"),
            http_client=http_client,
        )
        self.llm_config = {
            "model": llm_model,
            "max_tokens": 2000,
            "temperature": 0,  # deterministic decoding
        }

    def infer(self, messages):
        """Send *messages* (OpenAI chat format) and return the reply text."""
        response = self.openai_client.chat.completions.create(
            **self.llm_config, messages=messages
        )
        return response.choices[0].message.content


def normalize_answer(s):
    """Normalize an answer for exact-match comparison (SQuAD-style).

    Lowercases, strips ASCII punctuation, removes the articles a/an/the,
    and collapses whitespace. ``None`` maps to ``""``; non-strings are
    coerced with ``str()``.
    """
    if s is None:
        return ""
    if not isinstance(s, str):
        s = str(s)

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def setup_logging(log_file):
    """Configure root logging to both stderr and *log_file* (appending).

    Parent directories of *log_file* are created if missing. ``force=True``
    replaces any handlers installed by earlier basicConfig calls.
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    handlers = [logging.StreamHandler()]
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    handlers.append(logging.FileHandler(log_file, mode='a', encoding='utf-8'))
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=handlers,
        force=True
    )
    # Suppress noisy HTTP request logs (e.g., 401 Unauthorized) from httpx/openai
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)
    logging.getLogger("openai").setLevel(logging.WARNING)


def min_max_normalize(x):
    """Linearly scale *x* into [0, 1].

    Accepts any array-like (converted via ``np.asarray``, so plain Python
    lists work too — the original crashed on them). If every value is equal
    (zero range), returns an array of ones with the same shape as *x*.
    """
    x = np.asarray(x)
    min_val = np.min(x)
    max_val = np.max(x)
    range_val = max_val - min_val

    # Handle the case where all values are the same (range is zero)
    if range_val == 0:
        return np.ones_like(x)  # Return an array of ones with the same shape as x

    return (x - min_val) / range_val
