diff --git a/extism.go b/extism.go index 03c57e1..7dc6f71 100644 --- a/extism.go +++ b/extism.go @@ -28,7 +28,6 @@ type Runtime struct { Wazero wazero.Runtime Extism api.Module Env api.Module - ctx context.Context hasWasi bool } @@ -302,7 +301,12 @@ func (m *Manifest) UnmarshalJSON(data []byte) error { // Close closes the plugin by freeing the underlying resources. func (p *Plugin) Close() error { - return p.Runtime.Wazero.Close(p.Runtime.ctx) + return p.CloseWithContext(context.Background()) +} + +// Close closes the plugin by freeing the underlying resources. +func (p *Plugin) CloseWithContext(ctx context.Context) error { + return p.Runtime.Wazero.Close(ctx) } // NewPlugin creates a new Extism plugin with the given manifest, configuration, and host functions. @@ -351,17 +355,16 @@ func NewPlugin( Wazero: rt, Extism: extism, Env: env, - ctx: ctx, } if config.EnableWasi { - wasi_snapshot_preview1.MustInstantiate(c.ctx, c.Wazero) + wasi_snapshot_preview1.MustInstantiate(ctx, c.Wazero) c.hasWasi = true } for name, funcs := range hostModules { - _, err := buildHostModule(c.ctx, c.Wazero, name, funcs) + _, err := buildHostModule(ctx, c.Wazero, name, funcs) if err != nil { return nil, err } @@ -429,7 +432,7 @@ func NewPlugin( } } - m, err := c.Wazero.InstantiateWithConfig(c.ctx, data.Data, moduleConfig.WithName(data.Name)) + m, err := c.Wazero.InstantiateWithConfig(ctx, data.Data, moduleConfig.WithName(data.Name)) if err != nil { return nil, err } @@ -470,7 +473,7 @@ func NewPlugin( logLevel: logLevel, } - p.guestRuntime = detectGuestRuntime(p) + p.guestRuntime = detectGuestRuntime(ctx, p) return p, nil } @@ -482,29 +485,39 @@ func NewPlugin( // SetInput sets the input data for the plugin to be used in the next WebAssembly function call. func (plugin *Plugin) SetInput(data []byte) (uint64, error) { - _, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(plugin.Runtime.ctx) + return plugin.SetInputWithContext(context.Background(), data) +} + +// SetInput sets the input data for the plugin to be used in the next WebAssembly function call. +func (plugin *Plugin) SetInputWithContext(ctx context.Context, data []byte) (uint64, error) { + _, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(ctx) if err != nil { fmt.Println(err) return 0, errors.New("reset") } - ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(plugin.Runtime.ctx, uint64(len(data))) + ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(len(data))) if err != nil { return 0, err } plugin.Memory().Write(uint32(ptr[0]), data) - plugin.Runtime.Extism.ExportedFunction("input_set").Call(plugin.Runtime.ctx, ptr[0], uint64(len(data))) + plugin.Runtime.Extism.ExportedFunction("input_set").Call(ctx, ptr[0], uint64(len(data))) return ptr[0], nil } // GetOutput retrieves the output data from the last WebAssembly function call. func (plugin *Plugin) GetOutput() ([]byte, error) { - outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(plugin.Runtime.ctx) + return plugin.GetOutputWithContext(context.Background()) +} + +// GetOutput retrieves the output data from the last WebAssembly function call. +func (plugin *Plugin) GetOutputWithContext(ctx context.Context) ([]byte, error) { + outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(ctx) if err != nil { return []byte{}, err } - outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(plugin.Runtime.ctx) + outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(ctx) if err != nil { return []byte{}, err } @@ -524,7 +537,12 @@ func (plugin *Plugin) Memory() api.Memory { // GetError retrieves the error message from the last WebAssembly function call, if any. func (plugin *Plugin) GetError() string { - errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(plugin.Runtime.ctx) + return plugin.GetErrorWithContext(context.Background()) +} + +// GetError retrieves the error message from the last WebAssembly function call, if any. +func (plugin *Plugin) GetErrorWithContext(ctx context.Context) string { + errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(ctx) if err != nil { return "" } @@ -533,7 +551,7 @@ func (plugin *Plugin) GetError() string { return "" } - errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(plugin.Runtime.ctx, errOffs[0]) + errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, errOffs[0]) if err != nil { return "" } @@ -549,7 +567,7 @@ func (plugin *Plugin) FunctionExists(name string) bool { // Call a function by name with the given input, returning the output func (plugin *Plugin) Call(name string, data []byte) (uint32, []byte, error) { - return plugin.CallWithContext(plugin.Runtime.ctx, name, data) + return plugin.CallWithContext(context.Background(), name, data) } // Call a function by name with the given input and context, returning the output diff --git a/extism_test.go b/extism_test.go index d2349fe..635c9f3 100644 --- a/extism_test.go +++ b/extism_test.go @@ -518,7 +518,7 @@ func TestCancel(t *testing.T) { manifest := manifest("sleep.wasm") manifest.Config["duration"] = "3" // sleep for 3 seconds - ctx, cancel := context.WithCancel(context.Background()) + ctx := context.Background() config := PluginConfig{ ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(), EnableWasi: true, @@ -533,12 +533,13 @@ func TestCancel(t *testing.T) { defer plugin.Close() + ctx, cancel := context.WithCancel(context.Background()) go func() { time.Sleep(100 * time.Millisecond) cancel() }() - exit, _, err := plugin.Call("run_test", []byte{}) + exit, _, err := plugin.CallWithContext(ctx, "run_test", []byte{}) assert.Equal(t, sys.ExitCodeContextCanceled, exit, "Exit code must be `sys.ExitCodeContextCanceled`") assert.Equal(t, "module closed with context canceled", err.Error()) diff --git a/host.go b/host.go index e28e444..05fec11 100644 --- a/host.go +++ b/host.go @@ -123,7 +123,12 @@ func (p *CurrentPlugin) Memory() api.Memory { // Alloc a new memory block of the given length, returning its offset func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) { - out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(p.plugin.Runtime.ctx, uint64(n)) + return p.AllocWithContext(context.Background(), n) +} + +// Alloc a new memory block of the given length, returning its offset +func (p *CurrentPlugin) AllocWithContext(ctx context.Context, n uint64) (uint64, error) { + out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(n)) if err != nil { return 0, err } else if len(out) != 1 { @@ -135,7 +140,12 @@ func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) { // Free the memory block specified by the given offset func (p *CurrentPlugin) Free(offset uint64) error { - _, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(p.plugin.Runtime.ctx, uint64(offset)) + return p.FreeWithContext(context.Background(), offset) +} + +// Free the memory block specified by the given offset +func (p *CurrentPlugin) FreeWithContext(ctx context.Context, offset uint64) error { + _, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(ctx, uint64(offset)) if err != nil { return err } @@ -145,7 +155,12 @@ func (p *CurrentPlugin) Free(offset uint64) error { // Length returns the number of bytes allocated at the specified offset func (p *CurrentPlugin) Length(offs uint64) (uint64, error) { - out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(p.plugin.Runtime.ctx, uint64(offs)) + return p.LengthWithContext(context.Background(), offs) +} + +// Length returns the number of bytes allocated at the specified offset +func (p *CurrentPlugin) LengthWithContext(ctx context.Context, offs uint64) (uint64, error) { + out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, uint64(offs)) if err != nil { return 0, err } else if len(out) != 1 { diff --git a/runtime.go b/runtime.go index c4b14c4..dd39132 100644 --- a/runtime.go +++ b/runtime.go @@ -1,6 +1,8 @@ package extism import ( + "context" + "github.com/tetratelabs/wazero/api" ) @@ -20,15 +22,15 @@ type guestRuntime struct { initialized bool } -func detectGuestRuntime(p *Plugin) guestRuntime { +func detectGuestRuntime(ctx context.Context, p *Plugin) guestRuntime { m := p.Main - runtime, ok := haskellRuntime(p, m) + runtime, ok := haskellRuntime(ctx, p, m) if ok { return runtime } - runtime, ok = wasiRuntime(p, m) + runtime, ok = wasiRuntime(ctx, p, m) if ok { return runtime } @@ -40,7 +42,7 @@ func detectGuestRuntime(p *Plugin) guestRuntime { // Check for Haskell runtime initialization functions // Initialize Haskell runtime if `hs_init` and `hs_exit` are present, // by calling the `hs_init` export -func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { +func haskellRuntime(ctx context.Context, p *Plugin, m api.Module) (guestRuntime, bool) { initFunc := m.ExportedFunction("hs_init") if initFunc == nil { return guestRuntime{}, false @@ -56,12 +58,12 @@ func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { init := func() error { if reactorInit != nil { - _, err := reactorInit.Call(p.Runtime.ctx) + _, err := reactorInit.Call(ctx) if err != nil { p.Logf(LogLevelError, "Error running reactor _initialize: %s", err.Error()) } } - _, err := initFunc.Call(p.Runtime.ctx, 0, 0) + _, err := initFunc.Call(ctx, 0, 0) if err == nil { p.Log(LogLevelDebug, "Initialized Haskell language runtime.") } @@ -74,7 +76,7 @@ func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { } // Check for initialization functions defined by the WASI standard -func wasiRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { +func wasiRuntime(ctx context.Context, p *Plugin, m api.Module) (guestRuntime, bool) { if !p.Runtime.hasWasi { return guestRuntime{}, false } @@ -82,16 +84,16 @@ func wasiRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { // WASI supports two modules: Reactors and Commands // we prioritize Reactors over Commands // see: https://github.com/WebAssembly/WASI/blob/main/legacy/application-abi.md - if r, ok := reactorModule(m, p); ok { + if r, ok := reactorModule(ctx, m, p); ok { return r, ok } - return commandModule(m, p) + return commandModule(ctx, m, p) } // Check for `_initialize` this is used by WASI to initialize certain interfaces. -func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) { - init := findFunc(m, p, "_initialize") +func reactorModule(ctx context.Context, m api.Module, p *Plugin) (guestRuntime, bool) { + init := findFunc(ctx, m, p, "_initialize") if init == nil { return guestRuntime{}, false } @@ -104,8 +106,8 @@ func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) { // Check for `__wasm__call_ctors`, this is used by WASI to // initialize certain interfaces. -func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) { - init := findFunc(m, p, "__wasm_call_ctors") +func commandModule(ctx context.Context, m api.Module, p *Plugin) (guestRuntime, bool) { + init := findFunc(ctx, m, p, "__wasm_call_ctors") if init == nil { return guestRuntime{}, false } @@ -116,7 +118,7 @@ func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) { return guestRuntime{runtimeType: Wasi, init: init}, true } -func findFunc(m api.Module, p *Plugin, name string) func() error { +func findFunc(ctx context.Context, m api.Module, p *Plugin, name string) func() error { initFunc := m.ExportedFunction(name) if initFunc == nil { return nil @@ -130,7 +132,7 @@ func findFunc(m api.Module, p *Plugin, name string) func() error { return func() error { p.Logf(LogLevelDebug, "Calling %v", name) - _, err := initFunc.Call(p.Runtime.ctx) + _, err := initFunc.Call(ctx) return err } }