Skip to content

Commit

Permalink
Merge pull request #867 from trheyi/main
Browse files Browse the repository at this point in the history
Add retry and silent modes to assistant chat processing
  • Loading branch information
trheyi authored Feb 18, 2025
2 parents c5eade1 + 13c4e65 commit 09a38d8
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 23 deletions.
35 changes: 30 additions & 5 deletions neo/assistant/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/gin-gonic/gin"
jsoniter "github.com/json-iterator/go"
"github.com/yaoapp/gou/fs"
"github.com/yaoapp/kun/utils"
"github.com/yaoapp/kun/log"
chatctx "github.com/yaoapp/yao/neo/context"
chatMessage "github.com/yaoapp/yao/neo/message"
)
Expand Down Expand Up @@ -190,6 +190,14 @@ func (next *NextAction) Execute(c *gin.Context, ctx chatctx.Context, contents *c
return fmt.Errorf("input is required")
}

// Retry mode
retry := false
_, has = next.Payload["retry"]
if has {
retry = next.Payload["retry"].(bool)
ctx.Retry = retry
}

switch v := next.Payload["input"].(type) {
case string:
messages := chatMessage.Message{}
Expand Down Expand Up @@ -338,6 +346,10 @@ func (ast *Assistant) streamChat(
return 1 // continue
}

// Retry mode
msg.Retry = ctx.Retry // Retry mode
msg.Silent = ctx.Silent // Silent mode

// Handle error
if msg.Type == "error" {
value := msg.String()
Expand All @@ -348,7 +360,10 @@ func (ast *Assistant) streamChat(
value = res.Error
}
}
chatMessage.New().Error(value).Done().Write(c.Writer)
newMsg := chatMessage.New().Error(value).Done()
newMsg.Retry = ctx.Retry
newMsg.Silent = ctx.Silent
newMsg.Write(c.Writer)
return 0 // break
}

Expand Down Expand Up @@ -468,6 +483,9 @@ func (ast *Assistant) streamChat(
"delta": true,
})

output.Retry = ctx.Retry // Retry mode
output.Silent = ctx.Silent // Silent mode

if isFirst {
output.Assistant(ast.ID, ast.Name, ast.Avatar)
isFirst = false
Expand All @@ -489,6 +507,8 @@ func (ast *Assistant) streamChat(
"type": "text",
"delta": true,
"done": true,
"retry": ctx.Retry,
"silent": ctx.Silent,
}).
Write(c.Writer)
}
Expand Down Expand Up @@ -521,6 +541,8 @@ func (ast *Assistant) streamChat(
output := chatMessage.New().Done()
if res != nil && res.Output != nil {
output = chatMessage.New().Map(map[string]interface{}{"text": res.Output, "done": true})
output.Retry = ctx.Retry
output.Silent = ctx.Silent
}
output.Write(c.Writer)
done <- true
Expand All @@ -542,6 +564,8 @@ func (ast *Assistant) streamChat(
if err != nil {
return fmt.Errorf("error: %s", err.Error())
}
msg.Retry = ctx.Retry
msg.Silent = ctx.Silent
msg.Done().Write(c.Writer)
}

Expand Down Expand Up @@ -826,9 +850,10 @@ func (ast *Assistant) requestMessages(ctx context.Context, messages []chatMessag

// For debug environment, print the request messages
if os.Getenv("YAO_AGENT_PRINT_REQUEST_MESSAGES") == "true" {
fmt.Println("--- REQUEST_MESSAGES -----------------------------")
utils.Dump(newMessages)
fmt.Println("--- END REQUEST_MESSAGES -----------------------------")
for _, message := range newMessages {
raw, _ := jsoniter.MarshalToString(message)
log.Trace("[Request Message] %s", raw)
}
}

return newMessages, nil
Expand Down
94 changes: 78 additions & 16 deletions neo/assistant/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import (
"time"

"github.com/gin-gonic/gin"
"github.com/google/uuid"
jsoniter "github.com/json-iterator/go"
"github.com/yaoapp/gou/runtime/v8/bridge"
"github.com/yaoapp/kun/log"
chatctx "github.com/yaoapp/yao/neo/context"
"github.com/yaoapp/yao/neo/message"
chatMessage "github.com/yaoapp/yao/neo/message"
Expand Down Expand Up @@ -168,9 +170,7 @@ func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []
text = content[:endIndex]
text = strings.TrimSpace(text)
if os.Getenv("YAO_AGENT_PRINT_TOOL_CALL") == "true" {
fmt.Println("---- EXTRACTED TOOL CALL ----")
fmt.Println(text)
fmt.Println("---- END EXTRACTED TOOL CALL ----")
log.Trace("[TOOL CALL] %s", text)
}
}
}
Expand Down Expand Up @@ -309,7 +309,81 @@ func (ast *Assistant) call(ctx context.Context, method string, c *gin.Context, c
defer scriptCtx.Close()

// Add sendMessage function to the script context
scriptCtx.WithFunction("SendMessage", func(info *v8go.FunctionCallbackInfo) *v8go.Value {
scriptCtx.WithFunction("SendMessage", sendMessage(c, contents))
scriptCtx.WithFunction("Run", run(c, context))

// Check if the method exists
if !scriptCtx.Global().Has(method) {
return nil, fmt.Errorf(HookErrorMethodNotFound)
}

// Call the method directly in the current thread
args = append([]interface{}{context.Map()}, args...)
if scriptCtx != nil {
return scriptCtx.CallWith(ctx, method, args...)
}
return nil, nil
}

// Execute the assistant
func run(c *gin.Context, context chatctx.Context) func(info *v8go.FunctionCallbackInfo) *v8go.Value {
return func(info *v8go.FunctionCallbackInfo) *v8go.Value {

// Get the args
args := info.Args()
if len(args) < 2 {
return bridge.JsException(info.Context(), "Run requires at least two arguments")
}

// Get the assistant id
assistantID := args[0].String()

// Get the assistant
assistant, err := Get(assistantID)
if err != nil {
return bridge.JsException(info.Context(), err.Error())
}

// input []chatMessage.Message
input := args[1].String()

options := map[string]interface{}{}
if len(args) > 2 {
optionsRaw, err := bridge.GoValue(args[2], info.Context())
if err != nil {
return bridge.JsException(info.Context(), err.Error())
}

// Parse the options
if optionsRaw != nil {
switch v := optionsRaw.(type) {
case string:
err := jsoniter.UnmarshalFromString(v, &options)
if err != nil {
return bridge.JsException(info.Context(), err.Error())
}
case map[string]interface{}:
options = v
default:
return bridge.JsException(info.Context(), "Invalid options")
}
}
}

// Execute the assistant
context.AssistantID = assistantID
context.ChatID = fmt.Sprintf("chat_%s", uuid.New().String()) // New chat id
context.Silent = true // Silent mode
err = assistant.Execute(c, context, input, options) // Execute the assistant
if err != nil {
return bridge.JsException(info.Context(), err.Error())
}
return nil
}
}

func sendMessage(c *gin.Context, contents *chatMessage.Contents) func(info *v8go.FunctionCallbackInfo) *v8go.Value {
return func(info *v8go.FunctionCallbackInfo) *v8go.Value {

// Get the message
args := info.Args()
Expand Down Expand Up @@ -354,17 +428,5 @@ func (ast *Assistant) call(ctx context.Context, method string, c *gin.Context, c
default:
return bridge.JsException(info.Context(), "SendMessage requires a string or a map")
}
})

// Check if the method exists
if !scriptCtx.Global().Has(method) {
return nil, fmt.Errorf(HookErrorMethodNotFound)
}

// Call the method directly in the current thread
args = append([]interface{}{context.Map()}, args...)
if scriptCtx != nil {
return scriptCtx.CallWith(ctx, method, args...)
}
return nil, nil
}
2 changes: 2 additions & 0 deletions neo/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type Context struct {
Namespace string `json:"namespace,omitempty"`
Config map[string]interface{} `json:"config,omitempty"`
Signal interface{} `json:"signal,omitempty"`
Silent bool `json:"silent,omitempty"` // Silent mode
Retry bool `json:"retry,omitempty"` // Retry mode
Upload *FileUpload `json:"upload,omitempty"`
Version bool `json:"version,omitempty"` // Version support
RAG bool `json:"rag,omitempty"` // RAG support
Expand Down
5 changes: 3 additions & 2 deletions neo/message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type Message struct {
Data map[string]interface{} `json:"-"` // data for the message
Pending bool `json:"-"` // pending for the message
Hidden bool `json:"hidden,omitempty"` // hidden for the message (not show in the UI and history)
Retry bool `json:"retry,omitempty"` // retry for the message
Silent bool `json:"silent,omitempty"` // silent for the message (not show in the UI and history)
}

// Mention represents a mention
Expand Down Expand Up @@ -187,10 +189,9 @@ func NewAny(content interface{}) (*Message, error) {
// NewOpenAI create a new message from OpenAI response
func NewOpenAI(data []byte, isThinking bool) *Message {

// For Debug
// For debug environment, print the response data
if os.Getenv("YAO_AGENT_PRINT_RESPONSE_DATA") == "true" {
fmt.Printf("%s\n", string(data))
log.Trace("[Response Data] %s", string(data))
}

if data == nil || len(data) == 0 {
Expand Down

0 comments on commit 09a38d8

Please sign in to comment.