Diffstat (limited to 'scripts/fetch_papers.py')
-rw-r--r--   scripts/fetch_papers.py   230
1 file changed, 214 insertions(+), 16 deletions(-)
diff --git a/scripts/fetch_papers.py b/scripts/fetch_papers.py
index 3db80c7..fd3e628 100644
--- a/scripts/fetch_papers.py
+++ b/scripts/fetch_papers.py
@@ -24,7 +24,11 @@ import feedparser
 from datetime import datetime, timezone, timedelta
 from typing import List, Dict, Optional, Tuple
 from github import Github
-from openai import OpenAI
+from openai import OpenAI, AsyncOpenAI
+import asyncio
+import aiohttp
+from concurrent.futures import ThreadPoolExecutor
+import time

 # Configure logging
 logging.basicConfig(
@@ -54,17 +58,24 @@ CS_CATEGORIES = [
     "stat.ML"  # Machine Learning (Statistics)
 ]

-GPT_SYSTEM_PROMPT = """You are an expert researcher in AI/ML bias and fairness.
+GPT_SYSTEM_PROMPT = """You are an expert researcher in AI/ML bias, fairness, and social good applications.

-Your task is to analyze a paper's title and abstract to determine if it's relevant to LLM (Large Language Model) bias and fairness research.
+Your task is to analyze a paper's title and abstract to determine if it's relevant to bias and fairness research with social good implications.

 A paper is relevant if it discusses:
-- Bias in large language models, generative AI, or foundation models
-- Fairness issues in NLP models or text generation
-- Ethical concerns with language models
-- Demographic bias in AI systems
-- Alignment and safety of language models
-- Bias evaluation or mitigation in NLP
+- Bias, fairness, or discrimination in AI/ML systems with societal impact
+- Algorithmic fairness in healthcare, education, criminal justice, hiring, or finance
+- Demographic bias affecting marginalized or underrepresented groups
+- Data bias and its social consequences
+- Ethical AI and responsible AI deployment in society
+- AI safety and alignment with human values and social welfare
+- Bias evaluation, auditing, or mitigation in real-world applications
+- Representation and inclusion in AI systems and datasets
+- Social implications of AI bias (e.g., perpetuating inequality)
+- Fairness in recommendation systems, search engines, or content moderation
+- Bias in computer vision, NLP, or other AI domains affecting people
+
+The focus is on research that addresses how AI bias impacts society, vulnerable populations, or social justice, rather than purely technical ML advances without clear social relevance.

 Respond with exactly "1" if the paper is relevant, or "0" if it's not relevant. Do not include any other text in your response."""

@@ -76,6 +87,7 @@ class ArxivPaperFetcher:
     def __init__(self, openai_api_key: str):
         """Initialize the fetcher with OpenAI API key."""
         self.openai_client = OpenAI(api_key=openai_api_key)
+        self.async_openai_client = AsyncOpenAI(api_key=openai_api_key)
         self.session = requests.Session()
         self.session.headers.update({
             'User-Agent': 'PaperFetcher/1.0 (https://github.com/YurenHao0426/PaperFetcher)'
@@ -257,12 +269,15 @@ class ArxivPaperFetcher:
             "categories": [tag.term for tag in entry.tags] if hasattr(entry, 'tags') else []
         }

-    def filter_papers_with_gpt(self, papers: List[Dict]) -> List[Dict]:
+    def filter_papers_with_gpt(self, papers: List[Dict], use_parallel: bool = True,
+                               max_concurrent: int = 16) -> List[Dict]:
         """
         Filter papers using GPT-4o to identify bias-related research.

         Args:
             papers: List of paper dictionaries
+            use_parallel: Whether to use parallel processing (default: True)
+            max_concurrent: Maximum concurrent requests (default: 16)

         Returns:
             List of relevant papers
@@ -271,6 +286,15 @@ class ArxivPaperFetcher:
             logger.warning("⚠️ No papers to filter!")
             return []

+        if use_parallel and len(papers) > 5:
+            logger.info(f"🚀 Parallel mode: processing {len(papers)} papers (max concurrency: {max_concurrent})")
+            return self._filter_papers_parallel(papers, max_concurrent)
+        else:
+            logger.info(f"🔄 Sequential mode: processing {len(papers)} papers")
+            return self._filter_papers_sequential(papers)
+
+    def _filter_papers_sequential(self, papers: List[Dict]) -> List[Dict]:
+        """Serial processing of papers (original method)."""
         logger.info(f"🤖 Starting GPT-4o paper filtering...")
         logger.info(f"📝 Papers to process: {len(papers)}")

@@ -304,6 +328,111 @@ class ArxivPaperFetcher:

         return relevant_papers

+    def _filter_papers_parallel(self, papers: List[Dict], max_concurrent: int = 16) -> List[Dict]:
+        """Parallel processing of papers using asyncio."""
+        try:
+            # Check whether an event loop is already running
+            loop = asyncio.get_event_loop()
+            if loop.is_running():
+                # Run inside the existing event loop
+                import nest_asyncio
+                nest_asyncio.apply()
+                return loop.run_until_complete(self._async_filter_papers(papers, max_concurrent))
+            else:
+                # Create a new event loop
+                return asyncio.run(self._async_filter_papers(papers, max_concurrent))
+        except Exception as e:
+            logger.error(f"❌ Parallel processing failed: {e}")
+            logger.info("🔄 Falling back to sequential processing...")
+            return self._filter_papers_sequential(papers)
+
+    async def _async_filter_papers(self, papers: List[Dict], max_concurrent: int) -> List[Dict]:
+        """Async implementation of paper filtering."""
+        logger.info(f"🤖 Starting async GPT-4o filtering...")
+        logger.info(f"📝 Papers to process: {len(papers)}")
+
+        # Semaphore to cap the number of concurrent requests
+        semaphore = asyncio.Semaphore(max_concurrent)
+
+        # Create one task per paper
+        tasks = []
+        for i, paper in enumerate(papers):
+            task = self._check_paper_relevance_async(paper, semaphore, i + 1, len(papers))
+            tasks.append(task)
+
+        # Run all tasks concurrently
+        start_time = time.time()
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        total_time = time.time() - start_time
+
+        # Collect results
+        relevant_papers = []
+        successful_count = 0
+        error_count = 0
+
+        for i, result in enumerate(results):
+            if isinstance(result, Exception):
+                logger.error(f"❌ Error while processing paper {i+1}: {result}")
+                error_count += 1
+            elif isinstance(result, tuple):
+                is_relevant, paper = result
+                successful_count += 1
+                if is_relevant:
+                    relevant_papers.append(paper)
+                    logger.debug(f"✅ Paper {i+1} [relevant]: {paper['title'][:60]}...")
+                else:
+                    logger.debug(f"❌ Paper {i+1} [not relevant]: {paper['title'][:60]}...")
+
+        # Log final statistics
+        logger.info(f"🎯 Parallel GPT-4o filtering complete!")
+        logger.info(f"   - Total processing time: {total_time:.1f} s")
+        logger.info(f"   - Average per paper: {total_time/len(papers):.2f} s")
+        logger.info(f"   - Successfully processed: {successful_count} papers")
+        logger.info(f"   - Errors: {error_count} papers")
+        logger.info(f"   - Relevant papers found: {len(relevant_papers)}")
+
+        if successful_count > 0:
+            logger.info(f"   - Relevance ratio: {len(relevant_papers)/successful_count*100:.1f}%")
+
+        # Estimate the speedup over sequential processing
+        estimated_serial_time = len(papers) * 2.0  # assume ~2 seconds per paper when run serially
+        speedup = estimated_serial_time / total_time if total_time > 0 else 1
+        logger.info(f"   - Estimated speedup: {speedup:.1f}x")
+
+        return relevant_papers
+
+    async def _check_paper_relevance_async(self, paper: Dict, semaphore: asyncio.Semaphore,
+                                           index: int, total: int) -> tuple:
+        """Async version of paper relevance checking."""
+        async with semaphore:
+            try:
+                # Log progress every 10 papers
+                if index % 10 == 0:
+                    logger.info(f"📊 Parallel progress: {index}/{total} papers processed...")
+
+                prompt = f"Title: {paper['title']}\n\nAbstract: {paper['abstract']}"
+
+                response = await self.async_openai_client.chat.completions.create(
+                    model="gpt-4o",
+                    messages=[
+                        {"role": "system", "content": GPT_SYSTEM_PROMPT},
+                        {"role": "user", "content": prompt}
+                    ],
+                    temperature=0,
+                    max_tokens=1
+                )
+
+                result = response.choices[0].message.content.strip()
+                is_relevant = result == "1"
+
+                logger.debug(f"GPT-4o response #{index}: '{result}' -> {'relevant' if is_relevant else 'not relevant'}")
+                return (is_relevant, paper)
+
+            except Exception as e:
+                logger.error(f"❌ Async error while processing paper {index}: {e}")
+                # Re-raise so the caller can handle it
+                raise e
+
     def _check_paper_relevance(self, paper: Dict) -> bool:
         """Check if a paper is relevant using GPT-4o."""
         prompt = f"Title: {paper['title']}\n\nAbstract: {paper['abstract']}"
@@ -341,7 +470,13 @@ class ArxivPaperFetcher:

         if papers:
             logger.info(f"📋 Starting GPT-4o intelligent filtering stage...")
-            return self.filter_papers_with_gpt(papers)
+
+            # Read parallelism settings from environment variables
+            use_parallel = os.getenv("USE_PARALLEL", "true").lower() == "true"
+            max_concurrent = int(os.getenv("MAX_CONCURRENT", "16"))
+
+            return self.filter_papers_with_gpt(papers, use_parallel=use_parallel,
+                                               max_concurrent=max_concurrent)
         else:
             logger.warning("⚠️ No papers fetched; skipping GPT filtering step")
             return []
@@ -359,7 +494,13 @@ class ArxivPaperFetcher:

         if papers:
             logger.info(f"📋 Starting GPT-4o intelligent filtering stage...")
-            return self.filter_papers_with_gpt(papers)
+
+            # Historical mode defaults to higher concurrency (it processes many more papers)
+            use_parallel = os.getenv("USE_PARALLEL", "true").lower() == "true"
+            max_concurrent = int(os.getenv("MAX_CONCURRENT", "25"))  # higher default for historical mode
+
+            return self.filter_papers_with_gpt(papers, use_parallel=use_parallel,
+                                               max_concurrent=max_concurrent)
         else:
             logger.warning("⚠️ No papers fetched; skipping GPT filtering step")
             return []
@@ -375,7 +516,7 @@ class GitHubUpdater:
         self.repo = self.github.get_repo(repo_name)

     def update_readme_with_papers(self, papers: List[Dict], section_title: str = None):
-        """Update README with new papers."""
+        """Update README with new papers in reverse chronological order (newest first)."""
         if not papers:
             logger.info("No papers to add to README")
             return
@@ -407,8 +548,21 @@ class GitHubUpdater:
                 new_section += f"**Link:** [arXiv:{paper['arxiv_id']}]({paper['link']})\n\n"
                 new_section += "---\n\n"

-            # Update README
-            updated_content = current_content + new_section
+            # Insert new papers at the beginning to maintain reverse chronological order
+            # Find the end of the main documentation (after the project description and setup)
+            insert_position = self._find_papers_insert_position(current_content)
+
+            if insert_position > 0:
+                # Insert new section after the main documentation but before existing papers
+                updated_content = (current_content[:insert_position] +
+                                   new_section +
+                                   current_content[insert_position:])
+                logger.info(f"📝 Inserted new paper section ahead of older ones to keep reverse chronological order")
+            else:
+                # Fallback: append to end if can't find proper insertion point
+                updated_content = current_content + new_section
+                logger.info(f"📝 Appended new paper section to the end of the README (no suitable insertion point found)")
+
             commit_message = f"Auto-update: Added {len(papers)} new papers on {datetime.now(timezone.utc).strftime('%Y-%m-%d')}"

             self.repo.update_file(
@@ -419,11 +573,55 @@ class GitHubUpdater:
                 branch="main"
             )

-            logger.info(f"Successfully updated README with {len(papers)} papers")
+            logger.info(f"✅ Successfully updated README with {len(papers)} papers (reverse chronological order)")

         except Exception as e:
             logger.error(f"Error updating README: {e}")
             raise
+
+    def _find_papers_insert_position(self, content: str) -> int:
+        """Find the best position to insert new papers (after main doc, before existing papers)."""
+        lines = content.split('\n')
+
+        # Look for patterns that indicate the end of documentation and start of papers
+        # Search in order of priority
+        insert_patterns = [
+            "**Note**: This tool is designed for academic research purposes",  # End of README
+            "## Papers Updated on",  # Existing paper sections
+            "## Historical",  # Historical paper sections
+            "### ",  # Any section that might be a paper title
+            "---",  # Common separator before papers
+        ]
+
+        for pattern in insert_patterns:
+            for i, line in enumerate(lines):
+                if pattern in line:
+                    # Found a good insertion point - insert before this line
+                    # Convert line index to character position
+                    char_position = sum(len(lines[j]) + 1 for j in range(i))  # +1 for newline
+                    return char_position
+
+        # If no patterns found, try to find end of main documentation
+        # Look for the end of the last documentation section
+        last_doc_section = -1
+        for i, line in enumerate(lines):
+            if line.startswith('## ') and not line.startswith('## Papers') and not line.startswith('## Historical'):
+                last_doc_section = i
+
+        if last_doc_section >= 0:
+            # Find the end of this documentation section
+            section_end = len(lines)
+            for i in range(last_doc_section + 1, len(lines)):
+                if lines[i].startswith('## '):
+                    section_end = i
+                    break
+
+            # Insert after this section
+            char_position = sum(len(lines[j]) + 1 for j in range(section_end))
+            return char_position
+
+        # Final fallback: return 0 to trigger append behavior
+        return 0


 def main():
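
Usage sketch: a minimal, hypothetical driver showing how the parallel filtering path and the USE_PARALLEL / MAX_CONCURRENT environment knobs introduced above might be exercised. It assumes scripts/fetch_papers.py is importable as the module fetch_papers, that OPENAI_API_KEY is exported, and the sample paper dicts are illustrative only; this is not part of the commit itself.

# Hypothetical driver (illustrative only); assumes scripts/ is on PYTHONPATH
# and OPENAI_API_KEY is set in the environment.
import os
from fetch_papers import ArxivPaperFetcher

fetcher = ArxivPaperFetcher(openai_api_key=os.environ["OPENAI_API_KEY"])

papers = [
    {"title": "Auditing demographic bias in resume screening models",
     "abstract": "We measure disparate impact across demographic groups..."},
    # ... more paper dicts as produced by the arXiv parsing step; the parallel
    # path only kicks in once more than 5 papers are queued.
]

# Mirror the env-var wiring added in this commit.
use_parallel = os.getenv("USE_PARALLEL", "true").lower() == "true"
max_concurrent = int(os.getenv("MAX_CONCURRENT", "16"))

relevant = fetcher.filter_papers_with_gpt(
    papers,
    use_parallel=use_parallel,
    max_concurrent=max_concurrent,
)
print(f"Kept {len(relevant)} relevant papers")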
