Skip to content

Commit

Permalink
Correctly use context in plugin and provide alternative _WithContext …
Browse files Browse the repository at this point in the history
…methods

This commit ensures that the extism plugin struct does
not store a context object in itself as this is not idiomatic
Go and leads to unexpected behavior. Instead, a context object
is passed to each plugin method call that runs a long running
operation.

To not break the interface of the SDK, the existing SDK method
signatures were preserved, but they were changed to use
context.Background under the hood. Next to these, alternative
_WithContext methods were added, which allow the caller to
provide a custom context object.
  • Loading branch information
Marton6 committed Mar 11, 2024
1 parent a1a2815 commit 8e070df
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 35 deletions.
48 changes: 33 additions & 15 deletions extism.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ type Runtime struct {
Wazero wazero.Runtime
Extism api.Module
Env api.Module
ctx context.Context
hasWasi bool
}

Expand Down Expand Up @@ -302,7 +301,12 @@ func (m *Manifest) UnmarshalJSON(data []byte) error {

// Close closes the plugin by freeing the underlying resources.
func (p *Plugin) Close() error {
return p.Runtime.Wazero.Close(p.Runtime.ctx)
return p.CloseWithContext(context.Background())
}

// Close closes the plugin by freeing the underlying resources.
func (p *Plugin) CloseWithContext(ctx context.Context) error {
return p.Runtime.Wazero.Close(ctx)
}

// NewPlugin creates a new Extism plugin with the given manifest, configuration, and host functions.
Expand Down Expand Up @@ -351,17 +355,16 @@ func NewPlugin(
Wazero: rt,
Extism: extism,
Env: env,
ctx: ctx,
}

if config.EnableWasi {
wasi_snapshot_preview1.MustInstantiate(c.ctx, c.Wazero)
wasi_snapshot_preview1.MustInstantiate(ctx, c.Wazero)

c.hasWasi = true
}

for name, funcs := range hostModules {
_, err := buildHostModule(c.ctx, c.Wazero, name, funcs)
_, err := buildHostModule(ctx, c.Wazero, name, funcs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -429,7 +432,7 @@ func NewPlugin(
}
}

m, err := c.Wazero.InstantiateWithConfig(c.ctx, data.Data, moduleConfig.WithName(data.Name))
m, err := c.Wazero.InstantiateWithConfig(ctx, data.Data, moduleConfig.WithName(data.Name))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -470,7 +473,7 @@ func NewPlugin(
logLevel: logLevel,
}

p.guestRuntime = detectGuestRuntime(p)
p.guestRuntime = detectGuestRuntime(ctx, p)
return p, nil
}

Expand All @@ -482,29 +485,39 @@ func NewPlugin(

// SetInput sets the input data for the plugin to be used in the next WebAssembly function call.
func (plugin *Plugin) SetInput(data []byte) (uint64, error) {
_, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(plugin.Runtime.ctx)
return plugin.SetInputWithContext(context.Background(), data)
}

// SetInput sets the input data for the plugin to be used in the next WebAssembly function call.
func (plugin *Plugin) SetInputWithContext(ctx context.Context, data []byte) (uint64, error) {
_, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(ctx)
if err != nil {
fmt.Println(err)
return 0, errors.New("reset")
}

ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(plugin.Runtime.ctx, uint64(len(data)))
ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(len(data)))
if err != nil {
return 0, err
}
plugin.Memory().Write(uint32(ptr[0]), data)
plugin.Runtime.Extism.ExportedFunction("input_set").Call(plugin.Runtime.ctx, ptr[0], uint64(len(data)))
plugin.Runtime.Extism.ExportedFunction("input_set").Call(ctx, ptr[0], uint64(len(data)))
return ptr[0], nil
}

// GetOutput retrieves the output data from the last WebAssembly function call.
func (plugin *Plugin) GetOutput() ([]byte, error) {
outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(plugin.Runtime.ctx)
return plugin.GetOutputWithContext(context.Background())
}

// GetOutput retrieves the output data from the last WebAssembly function call.
func (plugin *Plugin) GetOutputWithContext(ctx context.Context) ([]byte, error) {
outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(ctx)
if err != nil {
return []byte{}, err
}

outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(plugin.Runtime.ctx)
outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(ctx)
if err != nil {
return []byte{}, err
}
Expand All @@ -524,7 +537,12 @@ func (plugin *Plugin) Memory() api.Memory {

// GetError retrieves the error message from the last WebAssembly function call, if any.
func (plugin *Plugin) GetError() string {
errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(plugin.Runtime.ctx)
return plugin.GetErrorWithContext(context.Background())
}

// GetError retrieves the error message from the last WebAssembly function call, if any.
func (plugin *Plugin) GetErrorWithContext(ctx context.Context) string {
errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(ctx)
if err != nil {
return ""
}
Expand All @@ -533,7 +551,7 @@ func (plugin *Plugin) GetError() string {
return ""
}

errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(plugin.Runtime.ctx, errOffs[0])
errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, errOffs[0])
if err != nil {
return ""
}
Expand All @@ -549,7 +567,7 @@ func (plugin *Plugin) FunctionExists(name string) bool {

// Call a function by name with the given input, returning the output
func (plugin *Plugin) Call(name string, data []byte) (uint32, []byte, error) {
return plugin.CallWithContext(plugin.Runtime.ctx, name, data)
return plugin.CallWithContext(context.Background(), name, data)
}

// Call a function by name with the given input and context, returning the output
Expand Down
5 changes: 3 additions & 2 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ func TestCancel(t *testing.T) {
manifest := manifest("sleep.wasm")
manifest.Config["duration"] = "3" // sleep for 3 seconds

ctx, cancel := context.WithCancel(context.Background())
ctx := context.Background()
config := PluginConfig{
ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(),
EnableWasi: true,
Expand All @@ -533,12 +533,13 @@ func TestCancel(t *testing.T) {

defer plugin.Close()

ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()

exit, _, err := plugin.Call("run_test", []byte{})
exit, _, err := plugin.CallWithContext(ctx, "run_test", []byte{})

assert.Equal(t, sys.ExitCodeContextCanceled, exit, "Exit code must be `sys.ExitCodeContextCanceled`")
assert.Equal(t, "module closed with context canceled", err.Error())
Expand Down
21 changes: 18 additions & 3 deletions host.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ func (p *CurrentPlugin) Memory() api.Memory {

// Alloc a new memory block of the given length, returning its offset
func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(p.plugin.Runtime.ctx, uint64(n))
return p.AllocWithContext(context.Background(), n)
}

// Alloc a new memory block of the given length, returning its offset
func (p *CurrentPlugin) AllocWithContext(ctx context.Context, n uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(n))
if err != nil {
return 0, err
} else if len(out) != 1 {
Expand All @@ -135,7 +140,12 @@ func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) {

// Free the memory block specified by the given offset
func (p *CurrentPlugin) Free(offset uint64) error {
_, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(p.plugin.Runtime.ctx, uint64(offset))
return p.FreeWithContext(context.Background(), offset)
}

// Free the memory block specified by the given offset
func (p *CurrentPlugin) FreeWithContext(ctx context.Context, offset uint64) error {
_, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(ctx, uint64(offset))
if err != nil {
return err
}
Expand All @@ -145,7 +155,12 @@ func (p *CurrentPlugin) Free(offset uint64) error {

// Length returns the number of bytes allocated at the specified offset
func (p *CurrentPlugin) Length(offs uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(p.plugin.Runtime.ctx, uint64(offs))
return p.LengthWithContext(context.Background(), offs)
}

// Length returns the number of bytes allocated at the specified offset
func (p *CurrentPlugin) LengthWithContext(ctx context.Context, offs uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, uint64(offs))
if err != nil {
return 0, err
} else if len(out) != 1 {
Expand Down
32 changes: 17 additions & 15 deletions runtime.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package extism

import (
"context"

"github.com/tetratelabs/wazero/api"
)

Expand All @@ -20,15 +22,15 @@ type guestRuntime struct {
initialized bool
}

func detectGuestRuntime(p *Plugin) guestRuntime {
func detectGuestRuntime(ctx context.Context, p *Plugin) guestRuntime {
m := p.Main

runtime, ok := haskellRuntime(p, m)
runtime, ok := haskellRuntime(ctx, p, m)
if ok {
return runtime
}

runtime, ok = wasiRuntime(p, m)
runtime, ok = wasiRuntime(ctx, p, m)
if ok {
return runtime
}
Expand All @@ -40,7 +42,7 @@ func detectGuestRuntime(p *Plugin) guestRuntime {
// Check for Haskell runtime initialization functions
// Initialize Haskell runtime if `hs_init` and `hs_exit` are present,
// by calling the `hs_init` export
func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {
func haskellRuntime(ctx context.Context, p *Plugin, m api.Module) (guestRuntime, bool) {
initFunc := m.ExportedFunction("hs_init")
if initFunc == nil {
return guestRuntime{}, false
Expand All @@ -56,12 +58,12 @@ func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {

init := func() error {
if reactorInit != nil {
_, err := reactorInit.Call(p.Runtime.ctx)
_, err := reactorInit.Call(ctx)
if err != nil {
p.Logf(LogLevelError, "Error running reactor _initialize: %s", err.Error())
}
}
_, err := initFunc.Call(p.Runtime.ctx, 0, 0)
_, err := initFunc.Call(ctx, 0, 0)
if err == nil {
p.Log(LogLevelDebug, "Initialized Haskell language runtime.")
}
Expand All @@ -74,24 +76,24 @@ func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {
}

// Check for initialization functions defined by the WASI standard
func wasiRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {
func wasiRuntime(ctx context.Context, p *Plugin, m api.Module) (guestRuntime, bool) {
if !p.Runtime.hasWasi {
return guestRuntime{}, false
}

// WASI supports two modules: Reactors and Commands
// we prioritize Reactors over Commands
// see: https://github.com/WebAssembly/WASI/blob/main/legacy/application-abi.md
if r, ok := reactorModule(m, p); ok {
if r, ok := reactorModule(ctx, m, p); ok {
return r, ok
}

return commandModule(m, p)
return commandModule(ctx, m, p)
}

// Check for `_initialize` this is used by WASI to initialize certain interfaces.
func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(m, p, "_initialize")
func reactorModule(ctx context.Context, m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(ctx, m, p, "_initialize")
if init == nil {
return guestRuntime{}, false
}
Expand All @@ -104,8 +106,8 @@ func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) {

// Check for `__wasm__call_ctors`, this is used by WASI to
// initialize certain interfaces.
func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(m, p, "__wasm_call_ctors")
func commandModule(ctx context.Context, m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(ctx, m, p, "__wasm_call_ctors")
if init == nil {
return guestRuntime{}, false
}
Expand All @@ -116,7 +118,7 @@ func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) {
return guestRuntime{runtimeType: Wasi, init: init}, true
}

func findFunc(m api.Module, p *Plugin, name string) func() error {
func findFunc(ctx context.Context, m api.Module, p *Plugin, name string) func() error {
initFunc := m.ExportedFunction(name)
if initFunc == nil {
return nil
Expand All @@ -130,7 +132,7 @@ func findFunc(m api.Module, p *Plugin, name string) func() error {

return func() error {
p.Logf(LogLevelDebug, "Calling %v", name)
_, err := initFunc.Call(p.Runtime.ctx)
_, err := initFunc.Call(ctx)
return err
}
}
Expand Down

0 comments on commit 8e070df

Please sign in to comment.