path: root/scripts/fetch_papers.py
author    blackhao <13851610112@163.com>  2025-03-30 18:12:37 -0500
committer blackhao <13851610112@163.com>  2025-03-30 18:12:37 -0500
commit    25eeb436f1c5f409855022131ee1dc8700a698f3 (patch)
tree      b649a3e4aa5d8cd514ba074e7c2952d8fb05c271 /scripts/fetch_papers.py
parent    150259fb5cd1c6c13802ccf74a758327c2101358 (diff)
api
Diffstat (limited to 'scripts/fetch_papers.py')
-rw-r--r--	scripts/fetch_papers.py	172
1 file changed, 96 insertions(+), 76 deletions(-)
diff --git a/scripts/fetch_papers.py b/scripts/fetch_papers.py
index 98e6eb0..444f67d 100644
--- a/scripts/fetch_papers.py
+++ b/scripts/fetch_papers.py
@@ -4,68 +4,80 @@ import feedparser
import datetime
from github import Github
-# List of categories you want
+#####################
+# 1. Configuration / constants
+#####################
+
ALLOWED_CATEGORIES = [
- "cs.AI", # Artificial Intelligence
- "cs.CL", # Computation and Language
- "cs.CV", # Computer Vision and Pattern Recognition
- "cs.LG", # Machine Learning
- "cs.NE", # Neural and Evolutionary Computing
- "cs.RO", # Robotics
- "cs.IR", # Information Retrieval
- "stat.ML" # Stat.ML
+ "cs.AI", "cs.CL", "cs.CV", "cs.LG", "cs.NE", "cs.RO",
+ "cs.IR", "stat.ML"
]
-def advanced_filter(entry):
- """
- Determine whether a paper contains a positive keyword combination (bias/fairness + LLM/transformer/GPT, etc.)
- and none of the negative keywords (statistics, physics, circuits, etc.).
- """
- import re
-
- # Lowercase everything once up front to avoid repeated processing
- title = getattr(entry, 'title', '').lower()
- summary = getattr(entry, 'summary', '').lower()
- full_text = title + " " + summary
-
- # 1) Positive keywords
- # - must contain "bias" or "fairness" (the broad concept)
- # - and at least one model-related keyword
- general_terms = ["bias", "fairness"]
- model_terms = ["llm", "language model", "transformer", "gpt", "nlp",
- "pretrained", "embedding", "generation", "alignment", "ai"]
-
- # 2) Negative keywords (exclude unrelated areas such as statistics, physics, circuits)
- negative_terms = [
- "estimation", "variance", "statistical", "sample", "sensor", "circuit",
- "quantum", "physics", "electronics", "hardware", "transistor", "amplifier"
- ]
-
- # Check positive keywords
- has_general = any(term in full_text for term in general_terms)
- has_model = any(term in full_text for term in model_terms)
-
- # Check negative keywords (any hit excludes the paper)
- has_negative = any(term in full_text for term in negative_terms)
+API_URL = "https://uiuc.chat/api/chat-api/chat"
+API_KEY = os.getenv("UIUC_API_KEY")  # your own API key
+MODEL_NAME = "qwen2.5:14b-instruct-fp16"  # your model
- # Return True only when both "general + model" hold and there is no negative hit
- return (has_general and has_model) and (not has_negative)
+SYSTEM_PROMPT = (
+ "Based on the given title and abstract, please determine if the paper "
+ "is relevant to both language models and bias (or fairness). "
+ "If yes, respond 1; otherwise respond 0."
+)
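# An illustrative exchange under the prompt above (the paper titles here are
# made up, not from the diff): the classifier replies with a bare "1" or "0".
#
#   user:      "Title: Measuring Gender Bias in Pretrained Language Models\n
#               Abstract: We probe transformer embeddings for stereotypes..."
#   assistant: "1"
#
# An off-topic paper (e.g. transistor amplifier design) should yield "0".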
+#####################
+# 2. Function: call the external API for relevance classification
+#####################
-def fetch_papers_wide_then_filter(days=1):
+def is_relevant_by_api(title, abstract):
"""
- Fetch all papers submitted to arXiv in the past N days (time-bounded), then filter locally:
- 1) keep only papers whose tags include an ALLOWED_CATEGORIES entry (for multi-category papers, any one match is enough)
- 2) use advanced_filter() to check whether the title or abstract qualifies
+ Call the external API with a title+abstract pair; return True (1) or False (0).
+ """
+ headers = {"Content-Type": "application/json"}
+ data = {
+ "model": MODEL_NAME,
+ "messages": [
+ {
+ "role": "system",
+ "content": SYSTEM_PROMPT
+ },
+ {
+ "role": "user",
+ # fill in our title+abstract as the "content"
+ "content": f"Title: {title}\nAbstract: {abstract}"
+ }
+ ],
+ "api_key": API_KEY,
+ "course_name": "llm-bias-papers",
+ "stream": False,
+ "temperature": 0.0,
+ "retrieval_only": False
+ }
+ try:
+ resp = requests.post(API_URL, headers=headers, json=data, timeout=30)
+ resp.raise_for_status()
+ # resp.json() should contain 'message'
+ response_msg = resp.json().get('message','')
+ # True if message == "1", else False
+ return (response_msg.strip() == "1")
+ except requests.RequestException as e:
+ print("[ERROR] calling external API:", e)
+ # on error, default to False (or handle it differently)
+ return False
+
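# A minimal usage sketch for the helper above (the title/abstract are made-up
# examples; assumes UIUC_API_KEY is set in the environment):
#
#   keep = is_relevant_by_api(
#       "Measuring Social Bias in Large Language Models",
#       "We audit pretrained transformers for stereotypical associations...")
#   print("relevant" if keep else "skip")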
+#####################
+# 3. Function: fetch papers and classify via the API
+#####################
+
+def fetch_arxiv_papers_with_api(days=1):
+ """
+ Fetch broadly, filter by category locally, then classify via the external API
"""
now_utc = datetime.datetime.now(datetime.timezone.utc)
start_utc = now_utc - datetime.timedelta(days=days)
start_str = start_utc.strftime("%Y%m%d%H%M")
end_str = now_utc.strftime("%Y%m%d%H%M")
- print(f"[DEBUG] date range (UTC): {start_str} ~ {end_str} (past {days} days)")
- # Build the search query from the time window only
+ print(f"[DEBUG] date range (UTC): {start_str} ~ {end_str}, days={days}")
search_query = f"submittedDate:[{start_str} TO {end_str}]"
base_url = "http://export.arxiv.org/api/query"
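# For reference, a standalone query against the arXiv API in the same style as
# the loop below (sortBy/sortOrder are optional arXiv API parameters added here
# for illustration; the date range is made up):
#
#   params = {
#       "search_query": "submittedDate:[202503290000 TO 202503300000]",
#       "start": 0,
#       "max_results": 100,
#       "sortBy": "submittedDate",
#       "sortOrder": "descending",
#   }
#   resp = requests.get("http://export.arxiv.org/api/query", params=params, timeout=30)
#   feed = feedparser.parse(resp.content)
#   print(len(feed.entries))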
@@ -82,59 +94,65 @@ def fetch_papers_wide_then_filter(days=1):
"max_results": step
}
print(f"[DEBUG] fetching: {start} -> {start+step}")
- resp = requests.get(base_url, params=params)
- if resp.status_code != 200:
- print("[ERROR] HTTP Status:", resp.status_code)
+ try:
+ resp = requests.get(base_url, params=params, timeout=30)
+ if resp.status_code != 200:
+ print("[ERROR] HTTP Status:", resp.status_code)
+ break
+ feed = feedparser.parse(resp.content)
+ except Exception as e:
+ print("[ERROR] fetching arXiv:", e)
break
- feed = feedparser.parse(resp.content)
batch = feed.entries
got_count = len(batch)
- print(f"[DEBUG] got {got_count} entries in this batch")
+ print(f"[DEBUG] got {got_count} entries in this batch.")
if got_count == 0:
- # no more results
break
all_entries.extend(batch)
start += step
-
- # safety cap
if start >= 3000:
print("[DEBUG] reached 3000, stop.")
break
print(f"[DEBUG] total retrieved in date range: {len(all_entries)}")
- # -- local filtering --
matched = []
- for e in all_entries:
- if hasattr(e, 'tags'):
- # e.tags: a list of objects with .term
- categories = [t.term for t in e.tags]
+ for entry in all_entries:
+ title = getattr(entry, 'title', '')
+ summary = getattr(entry, 'summary', '')
+ published = getattr(entry, 'published', '')
+ link = getattr(entry, 'link', '')
+ # check the categories first
+ if hasattr(entry, 'tags'):
+ categories = [t.term for t in entry.tags]
else:
categories = []
- # 1) does it belong to ALLOWED_CATEGORIES?
+ # is at least one category in ALLOWED_CATEGORIES?
in_allowed_cat = any(cat in ALLOWED_CATEGORIES for cat in categories)
if not in_allowed_cat:
continue
- # 2) more precise combined keyword filtering
- if advanced_filter(e):
+ # ask the external API: relevant or not
+ relevant = is_relevant_by_api(title, summary)
+ if relevant:
matched.append({
- "title": e.title,
- "published": e.published,
- "link": e.link,
+ "title": title,
+ "published": published,
+ "link": link,
"categories": categories
})
- print(f"[DEBUG] matched {len(matched)} papers after local filtering (categories + advanced_filter)")
+ print(f"[DEBUG] matched {len(matched)} papers after external API check.")
return matched
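# Shape of one record in the returned `matched` list (values illustrative,
# fields as built in the dict above):
#
#   {
#       "title": "Evaluating Fairness in Instruction-Tuned LLMs",
#       "published": "2025-03-29T12:34:56Z",
#       "link": "http://arxiv.org/abs/2503.01234v1",
#       "categories": ["cs.CL", "cs.LG"],
#   }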
+#####################
+# 4. Function: update README
+#####################
+
def update_readme_in_repo(papers, token, repo_name):
- """
- Append the matched papers to the target repo's README.md (main branch)
- """
if not papers:
print("[INFO] No matched papers, skip README update.")
return
@@ -142,7 +160,7 @@ def update_readme_in_repo(papers, token, repo_name):
g = Github(token)
repo = g.get_repo(repo_name)
- # read the existing README
+ # fetch the README
readme_file = repo.get_contents("README.md", ref="main")
old_content = readme_file.decoded_content.decode("utf-8")
@@ -166,13 +184,15 @@ def update_readme_in_repo(papers, token, repo_name):
)
print(f"[INFO] README updated with {len(papers)} papers.")
+#####################
+# 5. main
+#####################
+
def main():
- # 1) fetch the past 1 day
days = 1
- papers = fetch_papers_wide_then_filter(days=days)
- print(f"\n[RESULT] matched {len(papers)} papers. Will update README if not empty.")
+ papers = fetch_arxiv_papers_with_api(days=days)
+ print(f"[RESULT] matched {len(papers)} papers. Will update README if not empty.")
- # 2) update the README
github_token = os.getenv("TARGET_REPO_TOKEN")
target_repo_name = os.getenv("TARGET_REPO_NAME")