summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore16
-rw-r--r--readme.md29
-rw-r--r--requirements.txt3
-rw-r--r--run.py4
-rw-r--r--src/LinearRAG.py274
-rw-r--r--src/config.py3
-rw-r--r--src/evaluate.py1
-rw-r--r--src/ner.py3
8 files changed, 302 insertions, 31 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..00d0a46
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,16 @@
+
+results/
+.env
+.venv
+dataset/
+import/
+model/
+scripts/
+
+.DS_Store
+
+# all pycache files, including those in subdirectories
+__pycache__/
+*/__pycache__/
+
+
diff --git a/readme.md b/readme.md
index 00cb8a2..9dbb618 100644
--- a/readme.md
+++ b/readme.md
@@ -1,4 +1,4 @@
-# **LinearRAG: Linear Graph Retrieval-Augmented Generation on Large-scale Corpora**
+# **LinearRAG: Linear Graph Retrieval-Augmented Generation on Large-scale Corpora**
> A relation-free graph construction method for efficient GraphRAG. It eliminates LLM token costs during graph construction, making GraphRAG faster and more efficient than ever.
@@ -17,19 +17,16 @@
---
## 🚀 **Highlights**
-
-- ✅ **Context-Preserving**: Relation-free graph construction, relying on lightweight entity recognition and semantic linking to achieve comprehensive contextual comprehension.
+- ✅ **Context-Preserving**: Relation-free graph construction, relying on lightweight entity recognition and semantic linking to achieve comprehensive contextual comprehension.
- ✅ **Complex Reasoning**: Enables deep retrieval via semantic bridging, achieving multi-hop reasoning in a single retrieval pass without requiring explicit relational graphs.
- ✅ **High Scalability**: Zero LLM token consumption, faster processing speed, and linear time/space complexity.
-
+
<p align="center">
<img src="figure/main_figure.png" width="95%" alt="Framework Overview">
</p>
---
-
## 🎉 **News**
-
- **[2025-10-27]** We release **[LinearRAG](https://github.com/DEEP-PolyU/LinearRAG)**, a relation-free graph construction method for efficient GraphRAG.
- **[2025-06-06]** We release **[GraphRAG-Bench](https://github.com/GraphRAG-Bench/GraphRAG-Benchmark.git)**, the benchmark for evaluating GraphRAG models.
- **[2025-01-21]** We release the **[GraphRAG survey](https://github.com/DEEP-PolyU/Awesome-GraphRAG)**.
@@ -38,7 +35,7 @@
## 🛠️ **Usage**
-### 1️⃣ Install Dependencies
+### 1️⃣ Install Dependencies
**Step 1: Install Python packages**
@@ -53,7 +50,6 @@ python -m spacy download en_core_web_trf
```
> **Note:** For the `medical` dataset, you need to install the scientific/biomedical Spacy model:
-
```bash
pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.3/en_core_sci_scibert-0.5.3.tar.gz
```
@@ -82,6 +78,7 @@ Make sure the embedding model is available at:
model/all-mpnet-base-v2/
```
+
### 2️⃣ Quick Start Example
```bash
@@ -97,6 +94,7 @@ python run.py \
--dataset_name ${DATASET_NAME} \
--llm_model ${LLM_MODEL} \
--max_workers ${MAX_WORKERS}
+ --use_vectorized_retrieval # optional, use vectorized matrix-based retrieval for GPU acceleration if Strong GPU is available, otherwise use BFS iteration.
```
## 🎯 **Performance**
@@ -105,26 +103,17 @@ python run.py \
<img src="figure/generation_results.png" alt="framework" width="1000">
**Main results of end-to-end performance**
-
</div>
<div align="center">
<img src="figure/efficiency_result.png" alt="framework" width="1000">
-
-
-
-![framework](figure/efficiency_result.png)
-
-![framework](figure/efficiency_result.png)
-
**Efficiency and performance comparison.**
-
</div>
+
## 📖 Citation
If you find this work helpful, please consider citing us:
-
```bibtex
@article{zhuang2025linearrag,
title={LinearRAG: Linear Graph Retrieval Augmented Generation on Large-scale Corpora},
@@ -133,9 +122,5 @@ If you find this work helpful, please consider citing us:
year={2025}
}
```
-
-This project is licensed under the GNU General Public License v3.0 ([License](LICENSE.TXT)).
-
## 📬 Contact
-
✉️ Email: zhuangluyao523@gmail.com
diff --git a/requirements.txt b/requirements.txt
index da4c82d..0c4f40c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,10 @@
-httpx[socks]==0.28.1
+httpx==0.25.2
numpy==1.21.0
openai==1.54.5
pandas==1.3.0
python-igraph==0.11.8
scikit-learn==1.3.2
+scipy>=1.7.0
sentence-transformers==2.2.2
spacy==3.6.1
tqdm==4.67.1
diff --git a/run.py b/run.py
index 0d5d017..4ea5b3b 100644
--- a/run.py
+++ b/run.py
@@ -25,6 +25,7 @@ def parse_arguments():
parser.add_argument("--iteration_threshold", type=float, default=0.4, help="The threshold for iteration")
parser.add_argument("--passage_ratio", type=float, default=2, help="The ratio for passage")
parser.add_argument("--top_k_sentence", type=int, default=3, help="The top k sentence to use")
+ parser.add_argument("--use_vectorized_retrieval", action="store_true", help="Use vectorized matrix-based retrieval instead of BFS iteration")
return parser.parse_args()
@@ -59,7 +60,8 @@ def main():
max_iterations=args.max_iterations,
iteration_threshold=args.iteration_threshold,
passage_ratio=args.passage_ratio,
- top_k_sentence=args.top_k_sentence
+ top_k_sentence=args.top_k_sentence,
+ use_vectorized_retrieval=args.use_vectorized_retrieval
)
rag_model = LinearRAG(global_config=config)
rag_model.index(passages)
diff --git a/src/LinearRAG.py b/src/LinearRAG.py
index 0f4f6f8..9435883 100644
--- a/src/LinearRAG.py
+++ b/src/LinearRAG.py
@@ -11,11 +11,22 @@ from src.ner import SpacyNER
import igraph as ig
import re
import logging
+import torch
logger = logging.getLogger(__name__)
+
+
class LinearRAG:
def __init__(self, global_config):
self.config = global_config
logger.info(f"Initializing LinearRAG with config: {self.config}")
+ retrieval_method = "Vectorized Matrix-based" if self.config.use_vectorized_retrieval else "BFS Iteration"
+ logger.info(f"Using retrieval method: {retrieval_method}")
+
+ # Setup device for GPU acceleration
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if self.config.use_vectorized_retrieval:
+ logger.info(f"Using device: {self.device} for vectorized retrieval")
+
self.dataset_name = global_config.dataset_name
self.load_embedding_store()
self.llm_model = self.config.llm_model
@@ -70,7 +81,6 @@ class LinearRAG:
question_info["pred_answer"] = pred_ans
return retrieval_results
-
def retrieve(self, questions):
self.entity_hash_ids = list(self.entity_embedding_store.hash_id_to_text.keys())
self.entity_embeddings = np.array(self.entity_embedding_store.embeddings)
@@ -81,6 +91,19 @@ class LinearRAG:
self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()}
self.vertex_idx_to_node_name = {v.index: v["name"] for v in self.graph.vs if "name" in v.attributes()}
+ # Precompute sparse matrices for vectorized retrieval if needed
+ if self.config.use_vectorized_retrieval:
+ logger.info("Precomputing sparse adjacency matrices for vectorized retrieval...")
+ self._precompute_sparse_matrices()
+ e2s_shape = self.entity_to_sentence_sparse.shape
+ s2e_shape = self.sentence_to_entity_sparse.shape
+ e2s_nnz = self.entity_to_sentence_sparse._nnz()
+ s2e_nnz = self.sentence_to_entity_sparse._nnz()
+ logger.info(f"Matrices built: Entity-Sentence {e2s_shape}, Sentence-Entity {s2e_shape}")
+ logger.info(f"E2S Sparsity: {(1 - e2s_nnz / (e2s_shape[0] * e2s_shape[1])) * 100:.2f}% (nnz={e2s_nnz})")
+ logger.info(f"S2E Sparsity: {(1 - s2e_nnz / (s2e_shape[0] * s2e_shape[1])) * 100:.2f}% (nnz={s2e_nnz})")
+ logger.info(f"Device: {self.device}")
+
retrieval_results = []
for question_info in tqdm(questions, desc="Retrieving"):
question = question_info["question"]
@@ -104,11 +127,67 @@ class LinearRAG:
}
retrieval_results.append(result)
return retrieval_results
-
+
+ def _precompute_sparse_matrices(self):
+ """
+ Precompute and cache sparse adjacency matrices for efficient vectorized retrieval using PyTorch.
+ This is called once at the beginning of retrieve() to avoid rebuilding matrices per query.
+ """
+ num_entities = len(self.entity_hash_ids)
+ num_sentences = len(self.sentence_hash_ids)
-
+ # Build entity-to-sentence matrix (Mention matrix) using COO format
+ entity_to_sentence_indices = []
+ entity_to_sentence_values = []
+
+ for entity_hash_id, sentence_hash_ids in self.entity_hash_id_to_sentence_hash_ids.items():
+ entity_idx = self.entity_embedding_store.hash_id_to_idx[entity_hash_id]
+ for sentence_hash_id in sentence_hash_ids:
+ sentence_idx = self.sentence_embedding_store.hash_id_to_idx[sentence_hash_id]
+ entity_to_sentence_indices.append([entity_idx, sentence_idx])
+ entity_to_sentence_values.append(1.0)
+
+ # Build sentence-to-entity matrix
+ sentence_to_entity_indices = []
+ sentence_to_entity_values = []
+
+ for sentence_hash_id, entity_hash_ids in self.sentence_hash_id_to_entity_hash_ids.items():
+ sentence_idx = self.sentence_embedding_store.hash_id_to_idx[sentence_hash_id]
+ for entity_hash_id in entity_hash_ids:
+ entity_idx = self.entity_embedding_store.hash_id_to_idx[entity_hash_id]
+ sentence_to_entity_indices.append([sentence_idx, entity_idx])
+ sentence_to_entity_values.append(1.0)
+
+ # Convert to PyTorch sparse tensors (COO format, then convert to CSR for efficiency)
+ if len(entity_to_sentence_indices) > 0:
+ e2s_indices = torch.tensor(entity_to_sentence_indices, dtype=torch.long).t()
+ e2s_values = torch.tensor(entity_to_sentence_values, dtype=torch.float32)
+ self.entity_to_sentence_sparse = torch.sparse_coo_tensor(
+ e2s_indices, e2s_values, (num_entities, num_sentences), device=self.device
+ ).coalesce()
+ else:
+ self.entity_to_sentence_sparse = torch.sparse_coo_tensor(
+ torch.zeros((2, 0), dtype=torch.long), torch.zeros(0, dtype=torch.float32),
+ (num_entities, num_sentences), device=self.device
+ )
+
+ if len(sentence_to_entity_indices) > 0:
+ s2e_indices = torch.tensor(sentence_to_entity_indices, dtype=torch.long).t()
+ s2e_values = torch.tensor(sentence_to_entity_values, dtype=torch.float32)
+ self.sentence_to_entity_sparse = torch.sparse_coo_tensor(
+ s2e_indices, s2e_values, (num_sentences, num_entities), device=self.device
+ ).coalesce()
+ else:
+ self.sentence_to_entity_sparse = torch.sparse_coo_tensor(
+ torch.zeros((2, 0), dtype=torch.long), torch.zeros(0, dtype=torch.float32),
+ (num_sentences, num_entities), device=self.device
+ )
+
def graph_search_with_seed_entities(self, question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores):
- entity_weights, actived_entities = self.calculate_entity_scores(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores)
+ if self.config.use_vectorized_retrieval:
+ entity_weights, actived_entities = self.calculate_entity_scores_vectorized(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores)
+ else:
+ entity_weights, actived_entities = self.calculate_entity_scores(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores)
passage_weights = self.calculate_passage_scores(question_embedding,actived_entities)
node_weights = entity_weights + passage_weights
ppr_sorted_passage_indices,ppr_sorted_passage_scores = self.run_ppr(node_weights)
@@ -176,6 +255,191 @@ class LinearRAG:
iteration += 1
return entity_weights, actived_entities
+ def calculate_entity_scores_vectorized(self,question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores):
+ """
+ GPU-accelerated vectorized version using PyTorch sparse tensors.
+ Uses sparse representation for both matrices and entity score vectors for maximum efficiency.
+ Now includes proper dynamic pruning to match BFS behavior:
+ - Sentence deduplication (tracks used sentences)
+ - Per-entity top-k sentence selection
+ - Proper threshold-based pruning
+ """
+ # Initialize entity weights
+ entity_weights = np.zeros(len(self.graph.vs["name"]))
+ num_entities = len(self.entity_hash_ids)
+ num_sentences = len(self.sentence_hash_ids)
+
+ # Compute all sentence similarities with the question at once
+ question_emb = question_embedding.reshape(-1, 1) if len(question_embedding.shape) == 1 else question_embedding
+ sentence_similarities_np = np.dot(self.sentence_embeddings, question_emb).flatten()
+
+ # Convert to torch tensors and move to device
+ sentence_similarities = torch.from_numpy(sentence_similarities_np).float().to(self.device)
+
+ # Track used sentences for deduplication (like BFS version)
+ used_sentence_mask = torch.zeros(num_sentences, dtype=torch.bool, device=self.device)
+
+ # Initialize seed entity scores as sparse tensor
+ seed_indices = torch.tensor([[idx] for idx in seed_entity_indices], dtype=torch.long).t()
+ seed_values = torch.tensor(seed_entity_scores, dtype=torch.float32)
+ entity_scores_sparse = torch.sparse_coo_tensor(
+ seed_indices, seed_values, (num_entities,), device=self.device
+ ).coalesce()
+
+ # Also maintain a dense accumulator for total scores
+ entity_scores_dense = torch.zeros(num_entities, dtype=torch.float32, device=self.device)
+ entity_scores_dense.scatter_(0, torch.tensor(seed_entity_indices, device=self.device),
+ torch.tensor(seed_entity_scores, dtype=torch.float32, device=self.device))
+
+ # Initialize actived_entities
+ actived_entities = {}
+ for seed_entity_idx, seed_entity, seed_entity_hash_id, seed_entity_score in zip(
+ seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores
+ ):
+ actived_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 0)
+ seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id]
+ entity_weights[seed_entity_node_idx] = seed_entity_score
+
+ current_entity_scores_sparse = entity_scores_sparse
+
+ # Iterative matrix-based propagation using sparse matrices on GPU
+ for iteration in range(1, self.config.max_iterations):
+ # Convert sparse tensor to dense for threshold operation
+ current_entity_scores_dense = current_entity_scores_sparse.to_dense()
+
+ # Apply threshold to current scores
+ current_entity_scores_dense = torch.where(
+ current_entity_scores_dense >= self.config.iteration_threshold,
+ current_entity_scores_dense,
+ torch.zeros_like(current_entity_scores_dense)
+ )
+
+ # Get non-zero indices for sparse representation
+ nonzero_mask = current_entity_scores_dense > 0
+ nonzero_indices = torch.nonzero(nonzero_mask, as_tuple=False).squeeze(-1)
+
+ if len(nonzero_indices) == 0:
+ break
+
+ # Extract non-zero values and create sparse tensor
+ nonzero_values = current_entity_scores_dense[nonzero_indices]
+ current_entity_scores_sparse = torch.sparse_coo_tensor(
+ nonzero_indices.unsqueeze(0), nonzero_values, (num_entities,), device=self.device
+ ).coalesce()
+
+ # Step 1: Sparse entity scores @ Sparse E2S matrix
+ # Convert sparse vector to 2D for matrix multiplication
+ current_scores_2d = torch.sparse_coo_tensor(
+ torch.stack([nonzero_indices, torch.zeros_like(nonzero_indices)]),
+ nonzero_values,
+ (num_entities, 1),
+ device=self.device
+ ).coalesce()
+
+ # E @ E2S -> sentence activation scores (sparse @ sparse = dense)
+ sentence_activation = torch.sparse.mm(
+ self.entity_to_sentence_sparse.t(),
+ current_scores_2d
+ )
+ # Convert to dense before squeeze to avoid CUDA sparse tensor issues
+ if sentence_activation.is_sparse:
+ sentence_activation = sentence_activation.to_dense()
+ sentence_activation = sentence_activation.squeeze()
+
+ # Apply sentence deduplication: mask out used sentences
+ sentence_activation = torch.where(
+ used_sentence_mask,
+ torch.zeros_like(sentence_activation),
+ sentence_activation
+ )
+
+ # Step 2: Apply sentence similarities (element-wise on dense vector)
+ weighted_sentence_scores = sentence_activation * sentence_similarities
+
+ # Implement per-entity top-k sentence selection (more aggressive pruning)
+ # For vectorized efficiency, we use a tighter global approximation
+ num_active = len(nonzero_indices)
+ if num_active > 0 and self.config.top_k_sentence > 0:
+ # Calculate adaptive k based on number of active entities
+ # Use per-entity top-k approximation: num_active * top_k_sentence
+ k = min(int(num_active * self.config.top_k_sentence), len(weighted_sentence_scores))
+ if k > 0:
+ # Get top-k sentences
+ top_k_values, top_k_indices = torch.topk(weighted_sentence_scores, k)
+ # Zero out all non-top-k sentences
+ mask = torch.zeros_like(weighted_sentence_scores, dtype=torch.bool)
+ mask[top_k_indices] = True
+ weighted_sentence_scores = torch.where(
+ mask,
+ weighted_sentence_scores,
+ torch.zeros_like(weighted_sentence_scores)
+ )
+
+ # Mark these sentences as used for deduplication
+ used_sentence_mask[top_k_indices] = True
+
+ # Step 3: Weighted sentences @ S2E -> propagate to next entities
+ # Convert to sparse for more efficient computation
+ weighted_nonzero_mask = weighted_sentence_scores > 0
+ weighted_nonzero_indices = torch.nonzero(weighted_nonzero_mask, as_tuple=False).squeeze(-1)
+
+ if len(weighted_nonzero_indices) > 0:
+ weighted_nonzero_values = weighted_sentence_scores[weighted_nonzero_indices]
+ weighted_scores_2d = torch.sparse_coo_tensor(
+ torch.stack([weighted_nonzero_indices, torch.zeros_like(weighted_nonzero_indices)]),
+ weighted_nonzero_values,
+ (num_sentences, 1),
+ device=self.device
+ ).coalesce()
+
+ next_entity_scores_result = torch.sparse.mm(
+ self.sentence_to_entity_sparse.t(),
+ weighted_scores_2d
+ )
+ # Convert to dense before squeeze to avoid CUDA sparse tensor issues
+ if next_entity_scores_result.is_sparse:
+ next_entity_scores_result = next_entity_scores_result.to_dense()
+ next_entity_scores_dense = next_entity_scores_result.squeeze()
+ else:
+ next_entity_scores_dense = torch.zeros(num_entities, dtype=torch.float32, device=self.device)
+
+ # Update entity scores (accumulate in dense format)
+ entity_scores_dense += next_entity_scores_dense
+
+ # Update actived_entities dictionary (only for entities above threshold)
+ next_entity_scores_np = next_entity_scores_dense.cpu().numpy()
+ active_indices = np.where(next_entity_scores_np >= self.config.iteration_threshold)[0]
+ for entity_idx in active_indices:
+ score = next_entity_scores_np[entity_idx]
+ entity_hash_id = self.entity_hash_ids[entity_idx]
+ if entity_hash_id not in actived_entities or actived_entities[entity_hash_id][1] < score:
+ actived_entities[entity_hash_id] = (entity_idx, float(score), iteration)
+
+ # Prepare sparse tensor for next iteration
+ next_nonzero_mask = next_entity_scores_dense > 0
+ next_nonzero_indices = torch.nonzero(next_nonzero_mask, as_tuple=False).squeeze(-1)
+ if len(next_nonzero_indices) > 0:
+ next_nonzero_values = next_entity_scores_dense[next_nonzero_indices]
+ current_entity_scores_sparse = torch.sparse_coo_tensor(
+ next_nonzero_indices.unsqueeze(0), next_nonzero_values,
+ (num_entities,), device=self.device
+ ).coalesce()
+ else:
+ break
+
+ # Convert back to numpy for final processing
+ entity_scores_final = entity_scores_dense.cpu().numpy()
+
+ # Map entity scores to graph node weights (only for non-zero scores)
+ nonzero_indices = np.where(entity_scores_final > 0)[0]
+ for entity_idx in nonzero_indices:
+ score = entity_scores_final[entity_idx]
+ entity_hash_id = self.entity_hash_ids[entity_idx]
+ entity_node_idx = self.node_name_to_vertex_idx[entity_hash_id]
+ entity_weights[entity_node_idx] = float(score)
+
+ return entity_weights, actived_entities
+
def calculate_passage_scores(self, question_embedding, actived_entities):
passage_weights = np.zeros(len(self.graph.vs["name"]))
dpr_passage_indices, dpr_passage_scores = self.dense_passage_retrieval(question_embedding)
@@ -305,8 +569,6 @@ class LinearRAG:
self.graph.add_edges(edges)
self.graph.es['weight'] = weights
-
-
def add_entity_to_passage_edges(self, passage_hash_id_to_entities):
passage_to_entity_count ={}
passage_to_all_score = defaultdict(int)
diff --git a/src/config.py b/src/config.py
index 3ef136c..54677f3 100644
--- a/src/config.py
+++ b/src/config.py
@@ -17,4 +17,5 @@ class LinearRAGConfig:
passage_ratio: float = 1.5
passage_node_weight: float = 0.05
damping: float = 0.5
- iteration_threshold: float = 0.5 \ No newline at end of file
+ iteration_threshold: float = 0.5
+ use_vectorized_retrieval: bool = False # True for vectorized matrix computation, False for BFS iteration \ No newline at end of file
diff --git a/src/evaluate.py b/src/evaluate.py
index 32e652c..d814599 100644
--- a/src/evaluate.py
+++ b/src/evaluate.py
@@ -48,6 +48,7 @@ class Evaluator:
return 1
else:
return 0
+
def evaluate_sig_sample(self,idx,prediction):
pre_answer = prediction["pred_answer"]
gold_ans = prediction["gold_answer"]
diff --git a/src/ner.py b/src/ner.py
index 2ca4afb..4ee788e 100644
--- a/src/ner.py
+++ b/src/ner.py
@@ -1,5 +1,7 @@
import spacy
from collections import defaultdict
+import pdb
+
class SpacyNER:
def __init__(self,spacy_model):
@@ -25,6 +27,7 @@ class SpacyNER:
sentence_to_entities = defaultdict(list)
unique_entities = set()
passage_hash_id_to_entities = {}
+ pdb.set_trace()
for ent in doc.ents:
if ent.label_ == "ORDINAL" or ent.label_ == "CARDINAL":
continue