Skip to content

Commit

Permalink
Add Exa AI Search Support
Browse files Browse the repository at this point in the history
  • Loading branch information
pralberteinstein committed Jan 9, 2025
1 parent 5a61eec commit 06ab067
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions knowledge_storm/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,3 +1228,83 @@ def forward(
logging.error(f"Error occurs when searching query {query}: {e}")

return collected_results

class ExaSearch(dspy.Retrieve):
def __init__(self, k=3, exa_api_key=None):
super().__init__(k=k)
try:
from exa_py import Exa
except ImportError as err:
raise ImportError("ExaSearch requires `pip install exa_py`.") from err
if not exa_api_key and not os.environ.get("EXA_API_KEY"):
raise RuntimeError(
"You must supply exa_api_key or set environment variable EXA_API_KEY"
)
elif exa_api_key:
self.exa_api_key = exa_api_key
self.exa = Exa(api_key = exa_api_key)
else:
self.exa_api_key = os.environ["EXA_API_KEY"]
self.exa = Exa(os.getenv(exa_api_key))

def get_usage_and_reset(self):
usage = self.usage
self.usage = 0
return {"ExaSearch": usage}

def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
"""Search with Exa Search API for self.k top passages for query or queries.
Args:
query_or_queries (Union[str, List[str]]): _description_
exclude_urls (List[str], optional): _description_. Defaults to [].
Returns:
A list of dicts, each dict has keys of 'title', 'url', 'summary', 'score'.
"""
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
self.usage += len(queries)
collected_results = []

for query in queries:
try:
response = self.exa.search_and_contents(query, type="auto", summary=True, numresults = self.k)

for item in response.results:
if item.url not in exclude_urls:
result = {
'title': item.title,
'url': item.url,
'summary': item.summary,
'score': item.score
}
collected_results.append(result)

while len(collected_results) < self.k and len(response.results) == self.k:
additional_results = self.exa.search_and_contents(
query,
type="auto",
summary=True,
num_results=self.k,
exclude_domains=[result['url'].split('/')[2] for result in collected_results]
)
for item in additional_results.results:
if item.url not in exclude_urls and len(collected_results) < self.k:
result = {
'title': item.title,
'author': item.author,
'url': item.url,
'summary': item.summary,
'score': item.score
}
collected_results.append(result)

except Exception as e:
logging.error(f"Error occurs when searching query {query}: {e}")

return collected_results
pass

0 comments on commit 06ab067

Please sign in to comment.