Skip to content

Commit

Permalink
fix: filter and context selection (#672)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahules786 authored Feb 28, 2024
1 parent 6aa4b5b commit 366cb9f
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions src/ragas/testset/evolutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def generate_datarow(
assert self.generator_llm is not None, "generator_llm cannot be None"

node_content = [
f"{i}\t{n.page_content}" for i, n in enumerate(current_nodes.nodes)
f"{i+1}\t{n.page_content}" for i, n in enumerate(current_nodes.nodes)
]
results = await self.generator_llm.generate(
prompt=self.find_relevent_context_prompt.format(
Expand All @@ -197,15 +197,20 @@ async def generate_datarow(
if isinstance(relevent_contexts_result, dict)
else None
)

if relevant_context_indices is None:
relevant_context = CurrentNodes(
root_node=current_nodes.root_node, nodes=current_nodes.nodes
)
else:
selected_nodes = [current_nodes.nodes[i] for i in relevant_context_indices]
relevant_context = CurrentNodes(
root_node=selected_nodes[0], nodes=selected_nodes
selected_nodes = [
current_nodes.nodes[i - 1]
for i in relevant_context_indices
if i - 1 < len(current_nodes.nodes)
]
relevant_context = (
CurrentNodes(root_node=selected_nodes[0], nodes=selected_nodes)
if selected_nodes
else current_nodes
)

merged_nodes = self.merge_nodes(relevant_context)
Expand Down Expand Up @@ -278,10 +283,9 @@ async def _aevolve(
merged_node = self.merge_nodes(current_nodes)
passed = await self.node_filter.filter(merged_node)
if not passed["score"]:
nodes = self.docstore.get_random_nodes(k=1)
new_current_nodes = CurrentNodes(root_node=nodes[0], nodes=nodes)
current_nodes = self._get_new_random_node()
return await self.aretry_evolve(
current_tries, new_current_nodes, update_count=False
current_tries, current_nodes, update_count=False
)

logger.debug("keyphrases in merged node: %s", merged_node.keyphrases)
Expand Down Expand Up @@ -400,7 +404,7 @@ async def _acomplex_evolution(
)

assert self.evolution_filter is not None, "evolution filter cannot be None"
if not await self.evolution_filter.filter(simple_question, compressed_question):
if await self.evolution_filter.filter(simple_question, compressed_question):
# retry
current_nodes = self.se._get_new_random_node()
logger.debug(
Expand Down Expand Up @@ -500,7 +504,7 @@ async def _aevolve(
)

assert self.evolution_filter is not None, "evolution filter cannot be None"
if not await self.evolution_filter.filter(simple_question, compressed_question):
if await self.evolution_filter.filter(simple_question, compressed_question):
# retry
current_nodes = self.se._get_new_random_node()
return await self.aretry_evolve(current_tries, current_nodes)
Expand Down

0 comments on commit 366cb9f

Please sign in to comment.