output_parser.py
import ast
import re
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
)
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from typing_extensions import TypedDict

THOUGHT_PATTERN = r"Thought: ([^\n]*)"
ACTION_PATTERN = r"\n*(\d+)\. (\w+)\((.*)\)(\s*#\w+\n)?"
# $1 or ${1} -> 1
ID_PATTERN = r"\$\{?(\d+)\}?"
END_OF_PLAN = "<END_OF_PLAN>"
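# Illustrative example (not from the library) of the planner output these
# patterns are written to match; each numbered action may reference earlier
# results via $1 / ${1} placeholders:
#
#   Thought: I need the current weather before I can answer.
#   1. search(query="weather in San Francisco")
#   2. join()<END_OF_PLAN>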
### Helper functions


def _ast_parse(arg: str) -> Any:
    """Best-effort literal evaluation; fall back to the raw string."""
    try:
        return ast.literal_eval(arg)
    except:  # noqa
        return arg


def _parse_llm_compiler_action_args(
    args: str, tool: Union[str, BaseTool]
) -> Union[Tuple[Any, ...], Dict[str, Any]]:
    """Parse an action's argument string into keyword arguments for the tool."""
    if args == "":
        # Zero-argument call, e.g. `join()`.
        return ()
    if isinstance(tool, str):
        # "join" is a placeholder, not a real tool; its arguments are ignored.
        return ()
    extracted_args = {}
    tool_key = None
    prev_idx = None
    # Walk the tool's declared argument names and split the string on each
    # `key=` marker that appears in it.
    for key in tool.args.keys():
        if f"{key}=" in args:
            idx = args.index(f"{key}=")
            if prev_idx is not None:
                # Everything between the previous marker and this one belongs to
                # the previous key.
                extracted_args[tool_key] = _ast_parse(
                    args[prev_idx:idx].strip().rstrip(",")
                )
            args = args.split(f"{key}=", 1)[1]
            tool_key = key
            prev_idx = 0
    if prev_idx is not None:
        # The remainder of the string is the last key's value.
        extracted_args[tool_key] = _ast_parse(
            args[prev_idx:].strip().rstrip(",").rstrip(")")
        )
    return extracted_args
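
# Illustrative example, assuming a hypothetical tool whose schema declares
# `query` and then `k`:
#   _parse_llm_compiler_action_args('query="sf weather", k=3', tool)
#   -> {"query": "sf weather", "k": 3}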


def default_dependency_rule(idx: int, args: str) -> bool:
    """Return True if task `idx` is referenced (as $idx or ${idx}) in `args`."""
    matches = re.findall(ID_PATTERN, args)
    numbers = [int(match) for match in matches]
    return idx in numbers


def _get_dependencies_from_graph(
    idx: int, tool_name: str, args: Dict[str, Any]
) -> List[int]:
    """Get the indices of earlier tasks this task depends on."""
    if tool_name == "join":
        # join waits on every preceding task.
        return list(range(1, idx))
    # Otherwise, depend on any earlier task whose index is referenced in the args.
    return [i for i in range(1, idx) if default_dependency_rule(i, str(args))]
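
# Illustrative example: for a line like `3. summarize(text=$1)` (hypothetical tool),
# the parsed args contain "$1", so task 3 depends on task 1:
#   _get_dependencies_from_graph(3, "summarize", {"text": "$1"}) -> [1]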


class Task(TypedDict):
    idx: int
    tool: Union[BaseTool, str]  # the resolved tool, or the literal string "join"
    args: Union[Tuple[Any, ...], Dict[str, Any]]
    dependencies: List[int]
    thought: Optional[str]


def instantiate_task(
    tools: Sequence[BaseTool],
    idx: int,
    tool_name: str,
    args: Union[str, Any],
    thought: Optional[str] = None,
) -> Task:
    if tool_name == "join":
        tool = "join"
    else:
        # Resolve the named tool from the provided tool list.
        try:
            tool = tools[[tool.name for tool in tools].index(tool_name)]
        except ValueError as e:
            raise OutputParserException(f"Tool {tool_name} not found.") from e
    tool_args = _parse_llm_compiler_action_args(args, tool)
    dependencies = _get_dependencies_from_graph(idx, tool_name, tool_args)
    return Task(
        idx=idx,
        tool=tool,
        args=tool_args,
        dependencies=dependencies,
        thought=thought,
    )


class LLMCompilerPlanParser(BaseTransformOutputParser[dict], extra="allow"):
    """Planning output parser."""

    tools: List[BaseTool]

    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Task]:
        texts = []
        # TODO: Cleanup tuple state tracking here.
        thought = None
        for chunk in input:
            # Assume input is str. TODO: support vision/other formats
            text = chunk if isinstance(chunk, str) else str(chunk.content)
            for task, thought in self.ingest_token(text, texts, thought):
                yield task
        # Final possible task
        if texts:
            task, _ = self._parse_task("".join(texts), thought)
            if task:
                yield task

    def parse(self, text: str) -> List[Task]:
        return list(self._transform([text]))

    def stream(
        self,
        input: str | BaseMessage,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Iterator[Task]:
        yield from self.transform([input], config, **kwargs)

    def ingest_token(
        self, token: str, buffer: List[str], thought: Optional[str]
    ) -> Iterator[Tuple[Optional[Task], str]]:
        buffer.append(token)
        if "\n" in token:
            # A newline means at least one complete line is buffered; parse each
            # complete line and carry the trailing partial line forward.
            buffer_ = "".join(buffer).split("\n")
            suffix = buffer_[-1]
            for line in buffer_[:-1]:
                task, thought = self._parse_task(line, thought)
                if task:
                    yield task, thought
            buffer.clear()
            buffer.append(suffix)
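
    # Illustrative streaming behavior (hypothetical `search` tool): tokens like
    # ["1. sea", "rch(query=", '"x")\n2. '] produce no task until the chunk
    # containing "\n" arrives, at which point the complete line
    # `1. search(query="x")` is parsed and "2. " stays buffered.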

    def _parse_task(
        self, line: str, thought: Optional[str] = None
    ) -> Tuple[Optional[Task], Optional[str]]:
        task = None
        if match := re.match(THOUGHT_PATTERN, line):
            # Optionally, an action can be preceded by a thought.
            thought = match.group(1)
        elif match := re.match(ACTION_PATTERN, line):
            # If an action is parsed, build the task and reset the pending thought.
            idx, tool_name, args, _ = match.groups()
            idx = int(idx)
            task = instantiate_task(
                tools=self.tools,
                idx=idx,
                tool_name=tool_name,
                args=args,
                thought=thought,
            )
            thought = None
        # Any other line is dropped.
        return task, thought
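

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It assumes
    # langchain_core is installed (the module already imports it) and defines a
    # hypothetical `search` tool purely for illustration.
    from langchain_core.tools import tool

    @tool
    def search(query: str) -> str:
        """Hypothetical search tool used only for this demo."""
        return f"results for {query}"

    parser = LLMCompilerPlanParser(tools=[search])
    plan = (
        "Thought: Look up the weather first.\n"
        '1. search(query="weather in San Francisco")\n'
        "2. join()" + END_OF_PLAN + "\n"
    )
    for parsed in parser.parse(plan):
        print(parsed["idx"], parsed["tool"], parsed["args"], parsed["dependencies"])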