Skip to content

Commit

Permalink
Merge pull request #868 from trheyi/main
Browse files Browse the repository at this point in the history
Add callback support for assistant chat processing
  • Loading branch information
trheyi authored Feb 18, 2025
2 parents 09a38d8 + 979acb2 commit adc3d01
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 22 deletions.
42 changes: 27 additions & 15 deletions neo/assistant/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ func GetByConnector(connector string, name string) (*Assistant, error) {
}

// Execute implements the execute functionality
func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string, options map[string]interface{}) error {
func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string, options map[string]interface{}, callback ...interface{}) error {
contents := chatMessage.NewContents()
messages, err := ast.withHistory(ctx, input)
if err != nil {
return err
}
return ast.execute(c, ctx, messages, options, contents)
return ast.execute(c, ctx, messages, options, contents, callback...)
}

// Execute implements the execute functionality
func (ast *Assistant) execute(c *gin.Context, ctx chatctx.Context, input []chatMessage.Message, userOptions map[string]interface{}, contents *chatMessage.Contents) error {
func (ast *Assistant) execute(c *gin.Context, ctx chatctx.Context, input []chatMessage.Message, userOptions map[string]interface{}, contents *chatMessage.Contents, callback ...interface{}) error {

if contents == nil {
contents = chatMessage.NewContents()
Expand Down Expand Up @@ -123,11 +123,11 @@ func (ast *Assistant) execute(c *gin.Context, ctx chatctx.Context, input []chatM

// Update assistant id
ctx.AssistantID = res.AssistantID
return newAst.handleChatStream(c, ctx, input, options, contents)
return newAst.handleChatStream(c, ctx, input, options, contents, callback...)
}

// Only proceed with chat stream if no specific next action was handled
return ast.handleChatStream(c, ctx, input, options, contents)
return ast.handleChatStream(c, ctx, input, options, contents, callback...)
}

// Execute the next action
Expand Down Expand Up @@ -289,13 +289,13 @@ func (ast *Assistant) Call(c *gin.Context, payload APIPayload) (interface{}, err
}

// handleChatStream manages the streaming chat interaction with the AI
func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []chatMessage.Message, options map[string]interface{}, contents *chatMessage.Contents) error {
func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []chatMessage.Message, options map[string]interface{}, contents *chatMessage.Contents, callback ...interface{}) error {
clientBreak := make(chan bool, 1)
done := make(chan bool, 1)

// Chat with AI in background
go func() {
err := ast.streamChat(c, ctx, messages, options, clientBreak, done, contents)
err := ast.streamChat(c, ctx, messages, options, clientBreak, done, contents, callback...)
if err != nil {
chatMessage.New().Error(err).Done().Write(c.Writer)
}
Expand All @@ -320,7 +320,14 @@ func (ast *Assistant) streamChat(
options map[string]interface{},
clientBreak chan bool,
done chan bool,
contents *chatMessage.Contents) error {
contents *chatMessage.Contents,
callback ...interface{},
) error {

var cb interface{}
if len(callback) > 0 {
cb = callback[0]
}

errorRaw := ""
isFirst := true
Expand Down Expand Up @@ -363,7 +370,7 @@ func (ast *Assistant) streamChat(
newMsg := chatMessage.New().Error(value).Done()
newMsg.Retry = ctx.Retry
newMsg.Silent = ctx.Silent
newMsg.Write(c.Writer)
newMsg.Callback(cb).Write(c.Writer)
return 0 // break
}

Expand All @@ -380,8 +387,11 @@ func (ast *Assistant) streamChat(
if isThinking && msg.Type != "think" {
// add the think close tag
end := chatMessage.New().Map(map[string]interface{}{"text": "\n</think>\n", "type": "think", "delta": true})
end.Write(c.Writer)
end.ID = currentMessageID
end.Retry = ctx.Retry
end.Silent = ctx.Silent

end.Callback(cb).Write(c.Writer)
end.AppendTo(contents)
contents.UpdateType("think", map[string]interface{}{"text": contents.Text()}, currentMessageID)
isThinking = false
Expand All @@ -405,8 +415,10 @@ func (ast *Assistant) streamChat(

if msg.IsDone {
end := chatMessage.New().Map(map[string]interface{}{"text": "}\n</tool>\n", "type": "tool", "delta": true})
end.Write(c.Writer)
end.ID = currentMessageID
end.Retry = ctx.Retry
end.Silent = ctx.Silent
end.Callback(cb).Write(c.Writer)
end.AppendTo(contents)
contents.UpdateType("tool", map[string]interface{}{"text": contents.Text()}, currentMessageID)
isTool = false
Expand Down Expand Up @@ -485,12 +497,11 @@ func (ast *Assistant) streamChat(

output.Retry = ctx.Retry // Retry mode
output.Silent = ctx.Silent // Silent mode

if isFirst {
output.Assistant(ast.ID, ast.Name, ast.Avatar)
isFirst = false
}
output.Write(c.Writer)
output.Callback(cb).Write(c.Writer)
}

// Complete the stream
Expand All @@ -510,6 +521,7 @@ func (ast *Assistant) streamChat(
"retry": ctx.Retry,
"silent": ctx.Silent,
}).
Callback(cb).
Write(c.Writer)
}

Expand Down Expand Up @@ -544,7 +556,7 @@ func (ast *Assistant) streamChat(
output.Retry = ctx.Retry
output.Silent = ctx.Silent
}
output.Write(c.Writer)
output.Callback(cb).Write(c.Writer)
done <- true
return 0 // break
}
Expand All @@ -566,7 +578,7 @@ func (ast *Assistant) streamChat(
}
msg.Retry = ctx.Retry
msg.Silent = ctx.Silent
msg.Done().Write(c.Writer)
msg.Done().Callback(cb).Write(c.Writer)
}

return nil
Expand Down
75 changes: 69 additions & 6 deletions neo/assistant/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"strings"
"time"

"github.com/fatih/color"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
jsoniter "github.com/json-iterator/go"
"github.com/yaoapp/gou/process"
"github.com/yaoapp/gou/runtime/v8/bridge"
"github.com/yaoapp/kun/log"
chatctx "github.com/yaoapp/yao/neo/context"
Expand All @@ -19,7 +21,7 @@ import (
)

// HookInit initialize the assistant
func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []message.Message, options map[string]interface{}, contents *message.Contents) (*ResHookInit, error) {
func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []chatMessage.Message, options map[string]interface{}, contents *chatMessage.Contents) (*ResHookInit, error) {
// Create timeout context
ctx := ast.createBackgroundContext()
v, err := ast.call(ctx, "Init", c, contents, context, input, options)
Expand Down Expand Up @@ -310,7 +312,7 @@ func (ast *Assistant) call(ctx context.Context, method string, c *gin.Context, c

// Add sendMessage function to the script context
scriptCtx.WithFunction("SendMessage", sendMessage(c, contents))
scriptCtx.WithFunction("Run", run(c, context))
scriptCtx.WithFunction("Run", ast.run(c, context))

// Check if the method exists
if !scriptCtx.Global().Has(method) {
Expand All @@ -326,7 +328,7 @@ func (ast *Assistant) call(ctx context.Context, method string, c *gin.Context, c
}

// Execute the assistant
func run(c *gin.Context, context chatctx.Context) func(info *v8go.FunctionCallbackInfo) *v8go.Value {
func (ast *Assistant) run(c *gin.Context, context chatctx.Context) func(info *v8go.FunctionCallbackInfo) *v8go.Value {
return func(info *v8go.FunctionCallbackInfo) *v8go.Value {

// Get the args
Expand All @@ -345,11 +347,72 @@ func run(c *gin.Context, context chatctx.Context) func(info *v8go.FunctionCallba
}

// input []chatMessage.Message
var cb func(msg *chatMessage.Message)
input := args[1].String()
if len(args) > 2 {

goValue, err := bridge.GoValue(args[2], info.Context())
if err != nil {
return bridge.JsException(info.Context(), err.Error())
}

name := ""
userArgs := []interface{}{}
switch v := goValue.(type) {
case string:
name = v
case map[string]interface{}:
if fname, ok := v["name"].(string); ok {
name = fname
}
if args, ok := v["args"].([]interface{}); ok {
userArgs = args
}
}

if strings.Contains(name, ".") {
cb = func(msg *chatMessage.Message) {
cbArgs := []interface{}{}
cbArgs = append(cbArgs, msg)
cbArgs = append(cbArgs, userArgs...)
p, err := process.Of(name, cbArgs...)
if err != nil {
log.Error("Failed to get the process: %s", err.Error())
color.Red("Failed to get the process: %s", err.Error())
return
}
err = p.Execute()
if err != nil {
log.Error("Failed to execute the process: %s", err.Error())
color.Red("Failed to execute the process: %s", err.Error())
return
}
defer p.Release()
}
}

// Call self method
cb = func(msg *chatMessage.Message) {
cbArgs := []interface{}{}
cbArgs = append(cbArgs, msg)
cbArgs = append(cbArgs, userArgs...)
ctx, err := ast.Script.NewContext(context.Sid, nil)
if err != nil {
return
}
defer ctx.Close()
_, err = ctx.CallWith(context, name, cbArgs...)
if err != nil {
log.Error("Failed to call the method: %s", err.Error())
color.Red("Failed to call the method: %s", err.Error())
return
}
}
}

options := map[string]interface{}{}
if len(args) > 2 {
optionsRaw, err := bridge.GoValue(args[2], info.Context())
if len(args) > 3 {
optionsRaw, err := bridge.GoValue(args[3], info.Context())
if err != nil {
return bridge.JsException(info.Context(), err.Error())
}
Expand All @@ -374,7 +437,7 @@ func run(c *gin.Context, context chatctx.Context) func(info *v8go.FunctionCallba
context.AssistantID = assistantID
context.ChatID = fmt.Sprintf("chat_%s", uuid.New().String()) // New chat id
context.Silent = true // Silent mode
err = assistant.Execute(c, context, input, options) // Execute the assistant
err = assistant.Execute(c, context, input, options, cb) // Execute the assistant
if err != nil {
return bridge.JsException(info.Context(), err.Error())
}
Expand Down
2 changes: 1 addition & 1 deletion neo/assistant/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type API interface {
ReadBase64(ctx context.Context, fileID string) (string, error)

GetPlaceholder() *Placeholder
Execute(c *gin.Context, ctx chatctx.Context, input string, options map[string]interface{}) error
Execute(c *gin.Context, ctx chatctx.Context, input string, options map[string]interface{}, callback ...interface{}) error
Call(c *gin.Context, payload APIPayload) (interface{}, error)
}

Expand Down
25 changes: 25 additions & 0 deletions neo/message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,26 @@ func (m *Message) Bind(data map[string]interface{}) *Message {
return m
}

// Callback callback the message
func (m *Message) Callback(fn interface{}) *Message {
if fn != nil {
switch v := fn.(type) {
case func(msg *Message):
v(m)
break

case func():
v()
break

default:
fmt.Println("no match callback")
break
}
}
return m
}

// Write writes the message to response writer
func (m *Message) Write(w gin.ResponseWriter) bool {
defer func() {
Expand All @@ -702,6 +722,11 @@ func (m *Message) Write(w gin.ResponseWriter) bool {
}
}()

// Ignore silent messages
if m.Silent {
return true
}

data, err := jsoniter.Marshal(m)
if err != nil {
log.Error("%s", err.Error())
Expand Down

0 comments on commit adc3d01

Please sign in to comment.