Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions internal/guard/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ func NewWasmGuardWithOptions(ctx context.Context, name string, wasmBytes []byte,
}
return name
}()).
// WithStartFunctions with no args suppresses automatic _start execution
// so guard loading cannot block on stdin or perform unexpected I/O.
WithStartFunctions().
WithStdin(strings.NewReader("")). // Isolate stdin
WithStdout(stdoutWriter). // Keep WASM stdout off gateway stdout (MCP stream)
Expand Down Expand Up @@ -469,6 +471,13 @@ func (g *WasmGuard) Name() string {
return g.name
}

// IsHealthy reports whether the guard is still usable after previous WASM calls.
func (g *WasmGuard) IsHealthy() bool {
g.mu.Lock()
defer g.mu.Unlock()
return !g.failed
}

// callWasmGuardFunction serialises WASM access, sets the backend reference, marshals
// inputData, logs the input, calls the named WASM export, and returns the raw result.
// All three public dispatch methods (LabelAgent, LabelResource, LabelResponse) share
Expand Down Expand Up @@ -641,12 +650,16 @@ func unmarshalWasmResponse(funcName string, data []byte) (map[string]any, error)

// Close releases WASM runtime resources
func (g *WasmGuard) Close(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
cleanupCtx := context.WithoutCancel(ctx)
var moduleErr, runtimeErr error
if g.module != nil {
moduleErr = g.module.Close(ctx)
moduleErr = g.module.Close(cleanupCtx)
}
if g.runtime != nil {
runtimeErr = g.runtime.Close(ctx)
runtimeErr = g.runtime.Close(cleanupCtx)
}
return errors.Join(moduleErr, runtimeErr)
}
32 changes: 31 additions & 1 deletion internal/guard/wasm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func TestWasmGuardContextPropagation(t *testing.T) {
defer cancel()

// Create a wazero runtime that will close when the context is done.
runtime := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfig().WithCloseOnContextDone(true))
runtime := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfigCompiler().WithCloseOnContextDone(true))
defer func() {
_ = runtime.Close(ctx)
}()
Expand Down Expand Up @@ -1120,6 +1120,21 @@ func TestWasmGuardClose(t *testing.T) {
err := guard.Close(context.Background())
assert.NoError(t, err)
})

t.Run("close ignores caller cancellation during cleanup", func(t *testing.T) {
ctx := context.Background()
rt := wazero.NewRuntime(ctx)
mod, err := rt.InstantiateWithConfig(ctx, minimalGuardWasm, wazero.NewModuleConfig().WithName("close-guard"))
require.NoError(t, err)

guard := &WasmGuard{runtime: rt, module: mod}

cancelledCtx, cancel := context.WithCancel(ctx)
cancel()

err = guard.Close(cancelledCtx)
assert.NoError(t, err)
})
}

func TestWasmGuardName(t *testing.T) {
Expand All @@ -1134,6 +1149,21 @@ func TestWasmGuardName(t *testing.T) {
})
}

func TestWasmGuardIsHealthy(t *testing.T) {
t.Run("healthy guard reports true", func(t *testing.T) {
guard := &WasmGuard{}
assert.True(t, guard.IsHealthy())
})

t.Run("failed guard reports false", func(t *testing.T) {
guard := &WasmGuard{
failed: true,
failedErr: errors.New("trap"),
}
assert.False(t, guard.IsHealthy())
})
}

func TestParsePathLabeledResponse(t *testing.T) {
t.Run("invalid JSON returns error", func(t *testing.T) {
invalidJSON := []byte("not json")
Expand Down
Loading