From 8e070dfe1a42ddaf2bdb877224e581154a1eab89 Mon Sep 17 00:00:00 2001 From: Marton Soos Date: Mon, 11 Mar 2024 10:20:13 +0100 Subject: [PATCH] Correctly use context in plugin and provide alternative _WithContext methods This commit ensures that the extism plugin struct does not store a context object in itself as this is not idiomatic Go and leads to unexpected behavior. Instead, a context object is passed to each plugin method call that runs a long running operation. To not break the interface of the SDK, the existing SDK method signatures were preserved, but they were changed to use context.Background under the hood. Next to these, alternative _WithContext methods were added, which allow the caller to provide a custom context object. --- extism.go | 48 +++++++++++++++++++++++++++++++++--------------- extism_test.go | 5 +++-- host.go | 21 ++++++++++++++++++--- runtime.go | 32 +++++++++++++++++--------------- 4 files changed, 71 insertions(+), 35 deletions(-) diff --git a/extism.go b/extism.go index 03c57e1..7dc6f71 100644 --- a/extism.go +++ b/extism.go @@ -28,7 +28,6 @@ type Runtime struct { Wazero wazero.Runtime Extism api.Module Env api.Module - ctx context.Context hasWasi bool } @@ -302,7 +301,12 @@ func (m *Manifest) UnmarshalJSON(data []byte) error { // Close closes the plugin by freeing the underlying resources. func (p *Plugin) Close() error { - return p.Runtime.Wazero.Close(p.Runtime.ctx) + return p.CloseWithContext(context.Background()) +} + +// Close closes the plugin by freeing the underlying resources. +func (p *Plugin) CloseWithContext(ctx context.Context) error { + return p.Runtime.Wazero.Close(ctx) } // NewPlugin creates a new Extism plugin with the given manifest, configuration, and host functions. @@ -351,17 +355,16 @@ func NewPlugin( Wazero: rt, Extism: extism, Env: env, - ctx: ctx, } if config.EnableWasi { - wasi_snapshot_preview1.MustInstantiate(c.ctx, c.Wazero) + wasi_snapshot_preview1.MustInstantiate(ctx, c.Wazero) c.hasWasi = true } for name, funcs := range hostModules { - _, err := buildHostModule(c.ctx, c.Wazero, name, funcs) + _, err := buildHostModule(ctx, c.Wazero, name, funcs) if err != nil { return nil, err } @@ -429,7 +432,7 @@ func NewPlugin( } } - m, err := c.Wazero.InstantiateWithConfig(c.ctx, data.Data, moduleConfig.WithName(data.Name)) + m, err := c.Wazero.InstantiateWithConfig(ctx, data.Data, moduleConfig.WithName(data.Name)) if err != nil { return nil, err } @@ -470,7 +473,7 @@ func NewPlugin( logLevel: logLevel, } - p.guestRuntime = detectGuestRuntime(p) + p.guestRuntime = detectGuestRuntime(ctx, p) return p, nil } @@ -482,29 +485,39 @@ func NewPlugin( // SetInput sets the input data for the plugin to be used in the next WebAssembly function call. func (plugin *Plugin) SetInput(data []byte) (uint64, error) { - _, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(plugin.Runtime.ctx) + return plugin.SetInputWithContext(context.Background(), data) +} + +// SetInput sets the input data for the plugin to be used in the next WebAssembly function call. +func (plugin *Plugin) SetInputWithContext(ctx context.Context, data []byte) (uint64, error) { + _, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(ctx) if err != nil { fmt.Println(err) return 0, errors.New("reset") } - ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(plugin.Runtime.ctx, uint64(len(data))) + ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(len(data))) if err != nil { return 0, err } plugin.Memory().Write(uint32(ptr[0]), data) - plugin.Runtime.Extism.ExportedFunction("input_set").Call(plugin.Runtime.ctx, ptr[0], uint64(len(data))) + plugin.Runtime.Extism.ExportedFunction("input_set").Call(ctx, ptr[0], uint64(len(data))) return ptr[0], nil } // GetOutput retrieves the output data from the last WebAssembly function call. func (plugin *Plugin) GetOutput() ([]byte, error) { - outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(plugin.Runtime.ctx) + return plugin.GetOutputWithContext(context.Background()) +} + +// GetOutput retrieves the output data from the last WebAssembly function call. +func (plugin *Plugin) GetOutputWithContext(ctx context.Context) ([]byte, error) { + outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(ctx) if err != nil { return []byte{}, err } - outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(plugin.Runtime.ctx) + outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(ctx) if err != nil { return []byte{}, err } @@ -524,7 +537,12 @@ func (plugin *Plugin) Memory() api.Memory { // GetError retrieves the error message from the last WebAssembly function call, if any. func (plugin *Plugin) GetError() string { - errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(plugin.Runtime.ctx) + return plugin.GetErrorWithContext(context.Background()) +} + +// GetError retrieves the error message from the last WebAssembly function call, if any. +func (plugin *Plugin) GetErrorWithContext(ctx context.Context) string { + errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(ctx) if err != nil { return "" } @@ -533,7 +551,7 @@ func (plugin *Plugin) GetError() string { return "" } - errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(plugin.Runtime.ctx, errOffs[0]) + errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, errOffs[0]) if err != nil { return "" } @@ -549,7 +567,7 @@ func (plugin *Plugin) FunctionExists(name string) bool { // Call a function by name with the given input, returning the output func (plugin *Plugin) Call(name string, data []byte) (uint32, []byte, error) { - return plugin.CallWithContext(plugin.Runtime.ctx, name, data) + return plugin.CallWithContext(context.Background(), name, data) } // Call a function by name with the given input and context, returning the output diff --git a/extism_test.go b/extism_test.go index d2349fe..635c9f3 100644 --- a/extism_test.go +++ b/extism_test.go @@ -518,7 +518,7 @@ func TestCancel(t *testing.T) { manifest := manifest("sleep.wasm") manifest.Config["duration"] = "3" // sleep for 3 seconds - ctx, cancel := context.WithCancel(context.Background()) + ctx := context.Background() config := PluginConfig{ ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(), EnableWasi: true, @@ -533,12 +533,13 @@ func TestCancel(t *testing.T) { defer plugin.Close() + ctx, cancel := context.WithCancel(context.Background()) go func() { time.Sleep(100 * time.Millisecond) cancel() }() - exit, _, err := plugin.Call("run_test", []byte{}) + exit, _, err := plugin.CallWithContext(ctx, "run_test", []byte{}) assert.Equal(t, sys.ExitCodeContextCanceled, exit, "Exit code must be `sys.ExitCodeContextCanceled`") assert.Equal(t, "module closed with context canceled", err.Error()) diff --git a/host.go b/host.go index e28e444..05fec11 100644 --- a/host.go +++ b/host.go @@ -123,7 +123,12 @@ func (p *CurrentPlugin) Memory() api.Memory { // Alloc a new memory block of the given length, returning its offset func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) { - out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(p.plugin.Runtime.ctx, uint64(n)) + return p.AllocWithContext(context.Background(), n) +} + +// Alloc a new memory block of the given length, returning its offset +func (p *CurrentPlugin) AllocWithContext(ctx context.Context, n uint64) (uint64, error) { + out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(n)) if err != nil { return 0, err } else if len(out) != 1 { @@ -135,7 +140,12 @@ func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) { // Free the memory block specified by the given offset func (p *CurrentPlugin) Free(offset uint64) error { - _, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(p.plugin.Runtime.ctx, uint64(offset)) + return p.FreeWithContext(context.Background(), offset) +} + +// Free the memory block specified by the given offset +func (p *CurrentPlugin) FreeWithContext(ctx context.Context, offset uint64) error { + _, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(ctx, uint64(offset)) if err != nil { return err } @@ -145,7 +155,12 @@ func (p *CurrentPlugin) Free(offset uint64) error { // Length returns the number of bytes allocated at the specified offset func (p *CurrentPlugin) Length(offs uint64) (uint64, error) { - out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(p.plugin.Runtime.ctx, uint64(offs)) + return p.LengthWithContext(context.Background(), offs) +} + +// Length returns the number of bytes allocated at the specified offset +func (p *CurrentPlugin) LengthWithContext(ctx context.Context, offs uint64) (uint64, error) { + out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, uint64(offs)) if err != nil { return 0, err } else if len(out) != 1 { diff --git a/runtime.go b/runtime.go index c4b14c4..dd39132 100644 --- a/runtime.go +++ b/runtime.go @@ -1,6 +1,8 @@ package extism import ( + "context" + "github.com/tetratelabs/wazero/api" ) @@ -20,15 +22,15 @@ type guestRuntime struct { initialized bool } -func detectGuestRuntime(p *Plugin) guestRuntime { +func detectGuestRuntime(ctx context.Context, p *Plugin) guestRuntime { m := p.Main - runtime, ok := haskellRuntime(p, m) + runtime, ok := haskellRuntime(ctx, p, m) if ok { return runtime } - runtime, ok = wasiRuntime(p, m) + runtime, ok = wasiRuntime(ctx, p, m) if ok { return runtime } @@ -40,7 +42,7 @@ func detectGuestRuntime(p *Plugin) guestRuntime { // Check for Haskell runtime initialization functions // Initialize Haskell runtime if `hs_init` and `hs_exit` are present, // by calling the `hs_init` export -func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { +func haskellRuntime(ctx context.Context, p *Plugin, m api.Module) (guestRuntime, bool) { initFunc := m.ExportedFunction("hs_init") if initFunc == nil { return guestRuntime{}, false @@ -56,12 +58,12 @@ func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { init := func() error { if reactorInit != nil { - _, err := reactorInit.Call(p.Runtime.ctx) + _, err := reactorInit.Call(ctx) if err != nil { p.Logf(LogLevelError, "Error running reactor _initialize: %s", err.Error()) } } - _, err := initFunc.Call(p.Runtime.ctx, 0, 0) + _, err := initFunc.Call(ctx, 0, 0) if err == nil { p.Log(LogLevelDebug, "Initialized Haskell language runtime.") } @@ -74,7 +76,7 @@ func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { } // Check for initialization functions defined by the WASI standard -func wasiRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { +func wasiRuntime(ctx context.Context, p *Plugin, m api.Module) (guestRuntime, bool) { if !p.Runtime.hasWasi { return guestRuntime{}, false } @@ -82,16 +84,16 @@ func wasiRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { // WASI supports two modules: Reactors and Commands // we prioritize Reactors over Commands // see: https://github.com/WebAssembly/WASI/blob/main/legacy/application-abi.md - if r, ok := reactorModule(m, p); ok { + if r, ok := reactorModule(ctx, m, p); ok { return r, ok } - return commandModule(m, p) + return commandModule(ctx, m, p) } // Check for `_initialize` this is used by WASI to initialize certain interfaces. -func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) { - init := findFunc(m, p, "_initialize") +func reactorModule(ctx context.Context, m api.Module, p *Plugin) (guestRuntime, bool) { + init := findFunc(ctx, m, p, "_initialize") if init == nil { return guestRuntime{}, false } @@ -104,8 +106,8 @@ func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) { // Check for `__wasm__call_ctors`, this is used by WASI to // initialize certain interfaces. -func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) { - init := findFunc(m, p, "__wasm_call_ctors") +func commandModule(ctx context.Context, m api.Module, p *Plugin) (guestRuntime, bool) { + init := findFunc(ctx, m, p, "__wasm_call_ctors") if init == nil { return guestRuntime{}, false } @@ -116,7 +118,7 @@ func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) { return guestRuntime{runtimeType: Wasi, init: init}, true } -func findFunc(m api.Module, p *Plugin, name string) func() error { +func findFunc(ctx context.Context, m api.Module, p *Plugin, name string) func() error { initFunc := m.ExportedFunction(name) if initFunc == nil { return nil @@ -130,7 +132,7 @@ func findFunc(m api.Module, p *Plugin, name string) func() error { return func() error { p.Logf(LogLevelDebug, "Calling %v", name) - _, err := initFunc.Call(p.Runtime.ctx) + _, err := initFunc.Call(ctx) return err } }