Diffstat (limited to 'scripts/fetch_papers.py')
 scripts/fetch_papers.py | 87 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------
 1 file changed, 56 insertions(+), 31 deletions(-)
diff --git a/scripts/fetch_papers.py b/scripts/fetch_papers.py
index 713e6b4..1b86d58 100644
--- a/scripts/fetch_papers.py
+++ b/scripts/fetch_papers.py
@@ -4,13 +4,28 @@ import feedparser
import datetime
from github import Github
+# List of categories you want to keep
+ALLOWED_CATEGORIES = [
+ "cs.AI", # Artificial Intelligence
+ "cs.CL", # Computation and Language
+ "cs.CV", # Computer Vision and Pattern Recognition
+ "cs.LG", # Machine Learning
+ "cs.NE", # Neural and Evolutionary Computing
+ "cs.RO", # Robotics
+ "cs.CY", # Computers and Society
+ "cs.HC", # Human-Computer Interaction
+ "cs.IR", # Information Retrieval
+ "cs.GL", # General Literature
+ "cs.SI", # Social and Information Networks
+ "stat.ML" # Machine Learning (Statistics)
+]
+
def fetch_papers_wide_then_filter(days=1, keywords=None):
"""
- Fetch papers from the past N days (restricted only by submittedDate), then check locally:
- - whether the category is cs.* or stat.*
- - whether the title/abstract contains the keywords
- Returns a list where each element is a dict:
- { 'title':..., 'published':..., 'link':..., 'categories':[...] }
+ Fetch all papers submitted to arXiv within the past N days (restricted only by submittedDate),
+ then filter locally:
+ 1) keep only papers whose tags include a category from ALLOWED_CATEGORIES (for papers with multiple categories, any single match passes)
+ 2) the title or abstract contains one of the given keywords
"""
if keywords is None:
keywords = ["bias", "fairness"]
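For context, a minimal usage sketch of the function added above; the import path is an assumption, and the dict shape follows the docstring:

```python
# Usage sketch (assumes the script is importable as a module named fetch_papers).
from fetch_papers import fetch_papers_wide_then_filter

papers = fetch_papers_wide_then_filter(days=3, keywords=["bias", "fairness"])
for p in papers:
    # Each element: {'title':..., 'published':..., 'link':..., 'categories':[...]}
    print(p["title"], p["categories"], p["link"])
```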
@@ -20,15 +35,16 @@ def fetch_papers_wide_then_filter(days=1, keywords=None):
start_str = start_utc.strftime("%Y%m%d%H%M")
end_str = now_utc.strftime("%Y%m%d%H%M")
+ print(f"[DEBUG] date range (UTC): {start_str} ~ {end_str} (past {days} days)")
+ # Build the search_query using only the date range
search_query = f"submittedDate:[{start_str} TO {end_str}]"
- base_url = "http://export.arxiv.org/api/query"
+ base_url = "http://export.arxiv.org/api/query"
step = 100
start = 0
all_entries = []
- print(f"[DEBUG] Time range: {start_str} ~ {end_str}, days={days}")
while True:
params = {
"search_query": search_query,
@@ -38,45 +54,54 @@ def fetch_papers_wide_then_filter(days=1, keywords=None):
"max_results": step
}
print(f"[DEBUG] fetching: {start} -> {start+step}")
- r = requests.get(base_url, params=params)
- if r.status_code != 200:
- print("[ERROR] HTTP status:", r.status_code)
+ resp = requests.get(base_url, params=params)
+ if resp.status_code != 200:
+ print("[ERROR] HTTP Status:", resp.status_code)
break
- feed = feedparser.parse(r.content)
- got = len(feed.entries)
- print(f"[DEBUG] got {got} entries this batch.")
- if got == 0:
+ feed = feedparser.parse(resp.content)
+ batch = feed.entries
+ got_count = len(batch)
+ print(f"[DEBUG] got {got_count} entries in this batch")
+ if got_count == 0:
+ # No more results
break
- all_entries.extend(feed.entries)
+ all_entries.extend(batch)
start += step
+ # Safety cap
if start >= 3000:
print("[DEBUG] reached 3000, stop.")
break
- print(f"[DEBUG] total in date range: {len(all_entries)}")
+ print(f"[DEBUG] total retrieved in date range: {len(all_entries)}")
+ # -- Local filtering --
matched = []
for e in all_entries:
title = getattr(e, 'title', '')
summary = getattr(e, 'summary', '')
published = getattr(e, 'published', '')
link = getattr(e, 'link', '')
+
if hasattr(e, 'tags'):
+ # e.tags: a list of objects with .term
categories = [t.term for t in e.tags]
else:
categories = []
- # Check the category
- has_cs_stat = any(c.startswith("cs.") or c.startswith("stat.") for c in categories)
- if not has_cs_stat:
+ # 1) Is it in ALLOWED_CATEGORIES?
+ # Some papers have multiple categories; one match in ALLOWED_CATEGORIES is enough,
+ # e.g. "cs.IR", "cs.AI"
+ in_allowed_cat = any(cat in ALLOWED_CATEGORIES for cat in categories)
+ if not in_allowed_cat:
continue
- # Check the keywords
+ # 2) Does it contain a keyword?
text_lower = (title + " " + summary).lower()
- if any(kw.lower() in text_lower for kw in keywords):
+ has_keyword = any(kw.lower() in text_lower for kw in keywords)
+ if has_keyword:
matched.append({
"title": title,
"published": published,
@@ -84,12 +109,12 @@ def fetch_papers_wide_then_filter(days=1, keywords=None):
"categories": categories
})
- print(f"[DEBUG] matched {len(matched)} papers after local filter (cs./stat.+keywords)")
+ print(f"[DEBUG] matched {len(matched)} papers after local filtering (categories + keywords)")
return matched
def update_readme_in_repo(papers, token, repo_name):
"""
- Write the matched papers into the target repo's README.md
+ Append the matched papers to the target repo's README.md (main branch)
"""
if not papers:
print("[INFO] No matched papers, skip README update.")
@@ -98,9 +123,9 @@ def update_readme_in_repo(papers, token, repo_name):
g = Github(token)
repo = g.get_repo(repo_name)
- # Fetch the README contents
+ # Read the existing README
readme_file = repo.get_contents("README.md", ref="main")
- readme_content = readme_file.decoded_content.decode("utf-8")
+ old_content = readme_file.decoded_content.decode("utf-8")
now_utc_str = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
new_section = f"\n\n### Auto-captured papers on {now_utc_str}\n"
@@ -110,7 +135,7 @@ def update_readme_in_repo(papers, token, repo_name):
new_section += f" - Categories: {cat_str} \n"
new_section += f" - Link: {p['link']}\n\n"
- updated_content = readme_content + new_section
+ updated_content = old_content + new_section
commit_msg = f"Auto update README with {len(papers)} new papers"
repo.update_file(
@@ -123,16 +148,16 @@ def update_readme_in_repo(papers, token, repo_name):
print(f"[INFO] README updated with {len(papers)} papers.")
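The arguments to repo.update_file fall between the hunks above; for reference, a hedged sketch of how such a call is typically wired with PyGithub (the values below are assumptions, not necessarily what this commit passes):

```python
# Hedged PyGithub sketch; the actual arguments are elided from this diff.
repo.update_file(
    path="README.md",       # assumed: the same file read above
    message=commit_msg,
    content=updated_content,
    sha=readme_file.sha,    # PyGithub requires the current blob SHA
    branch="main",          # assumed: matches ref="main" in get_contents
)
```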
def main():
- # 1. Fetch papers from the past 3 days, keywords=["bias","fairness"]
- days = 1
+ # 1) Fetch the past 3 days, keywords=["bias","fairness"]
+ days = 3
keywords = ["bias", "fairness"]
papers = fetch_papers_wide_then_filter(days=days, keywords=keywords)
- print(f"[RESULT] matched {len(papers)} papers. Now let's update README in target repo if any.")
+ print(f"\n[RESULT] matched {len(papers)} papers. Will update README if not empty.")
- # 2. If any papers matched, update the README
- # Token and repo name come from secrets or env vars
+ # 2) Update the README
github_token = os.getenv("TARGET_REPO_TOKEN")
target_repo_name = os.getenv("TARGET_REPO_NAME")
+
if not github_token or not target_repo_name:
print("[ERROR] Missing environment variables: TARGET_REPO_TOKEN / TARGET_REPO_NAME.")
return
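Taken together, the paginated fetch above condenses to the sketch below; the "start" and sort parameters sit between hunks in this diff, so any request field beyond "search_query" and "max_results" is inferred, and the date range is illustrative:

```python
# Condensed sketch of the paginated arXiv fetch; "start" is inferred from the
# debug prints, and any elided sort parameters are omitted here.
import requests
import feedparser

base_url = "http://export.arxiv.org/api/query"
search_query = "submittedDate:[202501010000 TO 202501040000]"  # illustrative range
start, step = 0, 100
all_entries = []
while start < 3000:  # same safety cap as the script
    resp = requests.get(base_url, params={
        "search_query": search_query,
        "start": start,
        "max_results": step,
    })
    if resp.status_code != 200:
        break
    batch = feedparser.parse(resp.content).entries
    if not batch:
        break  # no more results in range
    all_entries.extend(batch)
    start += step

print(f"total retrieved in date range: {len(all_entries)}")
```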