Skip to content

Commit

Permalink
Add PGVector retriever support
Browse files Browse the repository at this point in the history
Related to stanford-oval#167

Add support for PGVector and GraphRAG retrievals.

* **knowledge_storm/rm.py**:
  - Add `PGVectorRetriever` class to support PGVector retrieval.
  - Add `GraphRAGRM` class to support GraphRAG retrieval.
  - Update import statements to include necessary libraries for PGVector and GraphRAG.

* **examples/storm_examples/run_storm_wiki_claude.py** and **examples/storm_examples/run_storm_wiki_gemini.py**:
  - Update environment variable descriptions to include PGVector and GraphRAG.
  - Add cases for `pgvector` and `graphrag` retrievers in the main function.

* **knowledge_storm/collaborative_storm/engine.py**:
  - Add `pgvector_retriever` attribute to `CollaborativeStormLMConfigs` class.
  - Add `set_pgvector_retriever` method to `CollaborativeStormLMConfigs` class.
  - Update import statements to include `PGVectorRetriever`.

* **tests/test_pgvector_retriever.py**:
  - Add unit tests for the new `PGVectorRetriever` class.
  • Loading branch information
vishwamartur committed Jan 3, 2025
1 parent 5a61eec commit c31981e
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 10 deletions.
14 changes: 9 additions & 5 deletions examples/storm_examples/run_storm_wiki_claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
STORM Wiki pipeline powered by Claude family models and You.com search engine.
You need to set up the following environment variables to run this script:
- ANTHROPIC_API_KEY: Anthropic API key
- YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key
- YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, TAVILY_API_KEY: Tavily API key, PGVECTOR_DB_URL: PGVector database URL, GRAPH_RAG_DATA: GraphRAG data in JSON format
Output will be structured as below
args.output_dir/
Expand All @@ -21,7 +21,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import ClaudeModel
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG, PGVectorRM, GraphRAGRM
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -76,8 +76,12 @@ def main(args):
rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True)
case 'searxng':
rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k)
case 'pgvector':
rm = PGVectorRM(db_url=os.getenv('PGVECTOR_DB_URL'), table_name='documents', embedding_model='BAAI/bge-m3', k=engine_args.search_top_k)
case 'graphrag':
rm = GraphRAGRM(graph_data=os.getenv('GRAPH_RAG_DATA'), embedding_model='BAAI/bge-m3', k=engine_args.search_top_k)
case _:
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"')
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", "pgvector", or "graphrag"')

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand All @@ -102,7 +106,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng', 'pgvector', 'graphrag'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand All @@ -127,4 +131,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
12 changes: 8 additions & 4 deletions examples/storm_examples/run_storm_wiki_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
STORM Wiki pipeline powered by Google Gemini models and search engine.
You need to set up the following environment variables to run this script:
- GOOGLE_API_KEY: Google API key (Can be obtained from https://ai.google.dev/gemini-api/docs/api-key)
- YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, or TAVILY_API_KEY: Tavily API key
- YDC_API_KEY: You.com API key; BING_SEARCH_API_KEY: Bing Search API key, SERPER_API_KEY: Serper API key, BRAVE_API_KEY: Brave API key, TAVILY_API_KEY: Tavily API key, PGVECTOR_DB_URL: PGVector database URL, GRAPH_RAG_DATA: GraphRAG data in JSON format
Output will be structured as below
args.output_dir/
Expand All @@ -21,7 +21,7 @@

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import GoogleModel
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG, PGVectorRM, GraphRAGRM
from knowledge_storm.utils import load_api_key

def main(args):
Expand Down Expand Up @@ -77,8 +77,12 @@ def main(args):
rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True)
case 'searxng':
rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k)
case 'pgvector':
rm = PGVectorRM(db_url=os.getenv('PGVECTOR_DB_URL'), table_name='documents', embedding_model='BAAI/bge-m3', k=engine_args.search_top_k)
case 'graphrag':
rm = GraphRAGRM(graph_data=os.getenv('GRAPH_RAG_DATA'), embedding_model='BAAI/bge-m3', k=engine_args.search_top_k)
case _:
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"')
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", "pgvector", or "graphrag"')

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand All @@ -103,7 +107,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng', 'pgvector', 'graphrag'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand Down
6 changes: 5 additions & 1 deletion knowledge_storm/collaborative_storm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..interface import LMConfigs, Agent
from ..logging_wrapper import LoggingWrapper
from ..lm import OpenAIModel, AzureOpenAIModel, TogetherClient
from ..rm import BingSearch
from ..rm import BingSearch, PGVectorRetriever


class CollaborativeStormLMConfigs(LMConfigs):
Expand All @@ -35,6 +35,7 @@ def __init__(self):
self.warmstart_outline_gen_lm = None
self.question_asking_lm = None
self.knowledge_base_lm = None
self.pgvector_retriever = None # Peebb

def init(
self,
Expand Down Expand Up @@ -159,6 +160,9 @@ def set_question_asking_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
def set_knowledge_base_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
self.knowledge_base_lm = model

def set_pgvector_retriever(self, retriever: PGVectorRetriever): # Peebb
self.pgvector_retriever = retriever # Peebb

def collect_and_reset_lm_usage(self):
lm_usage = {}
for attr_name in self.__dict__:
Expand Down
188 changes: 188 additions & 0 deletions knowledge_storm/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@

from .utils import WebPageHelper

import json

from pgvector.sqlalchemy import Vector
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.orm import sessionmaker, declarative_base

import networkx as nx
from networkx.readwrite import json_graph


class YouRM(dspy.Retrieve):
def __init__(self, ydc_api_key=None, k=3, is_valid_source: Callable = None):
Expand Down Expand Up @@ -333,6 +340,187 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
return collected_results


class PGVectorRetriever(dspy.Retrieve):
    """Retrieve passages from a PostgreSQL table via the pgvector extension.

    Rows are ranked by L2 distance between the stored ``vector`` column and
    the query embedding produced by a Hugging Face embedding model.
    """

    def __init__(
        self,
        db_url: str,
        table_name: str,
        embedding_model: str,
        k: int = 3,
        vector_dim: int = 1536,
    ):
        """
        Params:
            db_url: Database URL for the PostgreSQL database with PGVector extension.
            table_name: Name of the table containing the vectors.
            embedding_model: Name of the Hugging Face embedding model.
            k: Number of top chunks to retrieve.
            vector_dim: Dimensionality of the stored vectors. Defaults to 1536
                for backward compatibility; set it to match the embedding model
                in use (e.g. 1024 for 'BAAI/bge-m3').
        """
        super().__init__(k=k)
        # Validate arguments before doing any expensive work
        # (model instantiation, DB engine creation).
        if not db_url:
            raise ValueError("Please provide a database URL.")
        if not table_name:
            raise ValueError("Please provide a table name.")
        if not embedding_model:
            raise ValueError("Please provide an embedding model.")

        self.usage = 0
        self.db_url = db_url
        self.table_name = table_name

        model_kwargs = {"device": "cpu"}
        encode_kwargs = {"normalize_embeddings": True}
        self.model = HuggingFaceEmbeddings(
            model_name=embedding_model,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
        )

        self.engine = create_engine(self.db_url)
        self.Session = sessionmaker(bind=self.engine)
        # A per-instance declarative base keeps this mapping isolated from any
        # other SQLAlchemy models registered elsewhere in the process.
        self.Base = declarative_base()

        class VectorTable(self.Base):
            __tablename__ = self.table_name
            id = Column(Integer, primary_key=True)
            content = Column(String)
            title = Column(String)
            url = Column(String)
            description = Column(String)
            vector = Column(Vector(vector_dim))

        self.VectorTable = VectorTable

    def get_usage_and_reset(self):
        """Return the number of queries issued since the last call and reset the counter."""
        usage = self.usage
        self.usage = 0

        return {"PGVectorRetriever": usage}

    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
        """
        Search in your data for self.k top passages for query or queries.
        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            exclude_urls (List[str]): Dummy parameter to match the interface. Does not have any effect.
        Returns:
            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        self.usage += len(queries)
        collected_results = []

        session = self.Session()
        try:
            for query in queries:
                query_vector = self.model.embed_query(query)
                results = (
                    session.query(self.VectorTable)
                    .order_by(self.VectorTable.vector.l2_distance(query_vector))
                    .limit(self.k)
                    .all()
                )
                for result in results:
                    collected_results.append(
                        {
                            "description": result.description,
                            "snippets": [result.content],
                            "title": result.title,
                            "url": result.url,
                        }
                    )
        finally:
            # Release the DB session even if embedding or the query raises.
            session.close()

        return collected_results


# Compatibility alias: the example scripts import this class as ``PGVectorRM``
# (``from knowledge_storm.rm import ... PGVectorRM``); without the alias that
# import raises ImportError.
PGVectorRM = PGVectorRetriever


class GraphRAGRM(dspy.Retrieve):
    """Retrieve information from a node-link graph (GraphRAG style).

    Each graph node is expected to carry a precomputed embedding under the
    ``"vector"`` attribute plus ``"content"``/``"title"``/``"url"``/
    ``"description"`` metadata; queries are embedded with a Hugging Face model
    and matched to nodes by Euclidean distance.
    """

    def __init__(
        self,
        graph_data: Union[str, dict],
        embedding_model: str,
        k: int = 3,
    ):
        """
        Params:
            graph_data: Graph in node-link format, either as a dict or as a
                JSON string (e.g. the raw value of the GRAPH_RAG_DATA
                environment variable used by the example scripts).
            embedding_model: Name of the Hugging Face embedding model.
            k: Number of top chunks to retrieve.
        """
        super().__init__(k=k)
        # Validate arguments before building the graph or the model.
        if not graph_data:
            raise ValueError("Please provide graph data.")
        if not embedding_model:
            raise ValueError("Please provide an embedding model.")

        # The example scripts pass os.getenv('GRAPH_RAG_DATA'), i.e. a JSON
        # string, while node_link_graph needs a dict -- accept both.
        if isinstance(graph_data, str):
            graph_data = json.loads(graph_data)

        self.usage = 0
        self.graph = json_graph.node_link_graph(graph_data)

        model_kwargs = {"device": "cpu"}
        encode_kwargs = {"normalize_embeddings": True}
        self.model = HuggingFaceEmbeddings(
            model_name=embedding_model,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
        )

    def get_usage_and_reset(self):
        """Return the number of queries issued since the last call and reset the counter."""
        usage = self.usage
        self.usage = 0

        return {"GraphRAGRM": usage}

    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
        """
        Search in your data for self.k top passages for query or queries.
        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            exclude_urls (List[str]): Dummy parameter to match the interface. Does not have any effect.
        Returns:
            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        self.usage += len(queries)
        collected_results = []

        for query in queries:
            query_vector = self.model.embed_query(query)
            node_scores = {}
            for node, data in self.graph.nodes(data=True):
                node_vector = data.get("vector")
                if node_vector is not None:
                    # Squared Euclidean distance: the square root is omitted
                    # because it does not change the ranking.
                    score = sum(
                        (a - b) ** 2 for a, b in zip(query_vector, node_vector)
                    )
                    node_scores[node] = score

            top_nodes = sorted(node_scores, key=node_scores.get)[: self.k]
            for node in top_nodes:
                data = self.graph.nodes[node]
                collected_results.append(
                    {
                        "description": data.get("description", ""),
                        "snippets": [data.get("content", "")],
                        "title": data.get("title", ""),
                        "url": data.get("url", ""),
                    }
                )

        return collected_results


class StanfordOvalArxivRM(dspy.Retrieve):
"""[Alpha] This retrieval class is for internal use only, not intended for the public."""

Expand Down
44 changes: 44 additions & 0 deletions tests/test_pgvector_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import unittest
from unittest.mock import patch, MagicMock
from knowledge_storm.rm import PGVectorRetriever

class TestPGVectorRetriever(unittest.TestCase):
    """Unit tests for PGVectorRetriever with the DB and embedding model mocked out."""

    @patch('knowledge_storm.rm.create_engine')
    @patch('knowledge_storm.rm.HuggingFaceEmbeddings')
    def setUp(self, MockHuggingFaceEmbeddings, MockCreateEngine):
        # Patch the embedding model and DB engine so construction performs no
        # network or database access. The retriever keeps references to the
        # mock objects, so they remain usable after the patches are undone.
        self.mock_model = MockHuggingFaceEmbeddings.return_value
        self.mock_engine = MockCreateEngine.return_value
        self.retriever = PGVectorRetriever(
            db_url='postgresql://user:password@localhost/dbname',
            table_name='documents',
            embedding_model='BAAI/bge-m3',
            k=3
        )

    def test_forward(self):
        # ``Session`` is an *instance* attribute assigned in __init__, so it
        # must be replaced on the instance. Patching
        # 'knowledge_storm.rm.PGVectorRetriever.Session' would raise
        # AttributeError because the class itself has no such attribute.
        mock_session = MagicMock()
        self.retriever.Session = MagicMock(return_value=mock_session)
        mock_session.query.return_value.order_by.return_value.limit.return_value.all.return_value = [
            MagicMock(description='desc1', content='content1', title='title1', url='url1'),
            MagicMock(description='desc2', content='content2', title='title2', url='url2')
        ]
        self.retriever.model.embed_query.return_value = [0.1, 0.2, 0.3]

        results = self.retriever.forward('test query', [])

        self.assertEqual(len(results), 2)
        self.assertEqual(results[0]['description'], 'desc1')
        self.assertEqual(results[0]['snippets'], ['content1'])
        self.assertEqual(results[0]['title'], 'title1')
        self.assertEqual(results[0]['url'], 'url1')
        # The DB session must be released after retrieval.
        mock_session.close.assert_called_once()

    def test_get_usage_and_reset(self):
        self.retriever.usage = 5
        usage = self.retriever.get_usage_and_reset()
        self.assertEqual(usage, {'PGVectorRetriever': 5})
        self.assertEqual(self.retriever.usage, 0)

# Allow running this test module directly: ``python tests/test_pgvector_retriever.py``.
if __name__ == '__main__':
    unittest.main()

0 comments on commit c31981e

Please sign in to comment.