diff --git a/api/transport/propagation.go b/api/transport/propagation.go index 7c27ba106..28e5e849a 100644 --- a/api/transport/propagation.go +++ b/api/transport/propagation.go @@ -22,6 +22,8 @@ package transport import ( "context" + "strings" + "sync" "time" "github.com/opentracing/opentracing-go" @@ -29,6 +31,11 @@ import ( opentracinglog "github.com/opentracing/opentracing-go/log" ) +const ( + tchannelTracingKeyPrefix = "$tracing$" + tchannelTracingKeyMappingSize = 100 +) + // CreateOpenTracingSpan creates a new context with a started span type CreateOpenTracingSpan struct { Tracer opentracing.Tracer @@ -119,3 +126,98 @@ func UpdateSpanWithErr(span opentracing.Span, err error) error { } return err } + +// GetPropagationFormat returns the opentracing propagation depends on transport. +// For TChannel, the format is opentracing.TextMap +// For HTTP and gRPC, the format is opentracing.HTTPHeaders +func GetPropagationFormat(transport string) opentracing.BuiltinFormat { + if transport == "tchannel" { + return opentracing.TextMap + } + return opentracing.HTTPHeaders +} + +// PropagationCarrier is an interface to combine both reader and writer interface +type PropagationCarrier interface { + opentracing.TextMapReader + opentracing.TextMapWriter +} + +// GetPropagationCarrier get the propagation carrier depends on the transport. +// The carrier is used for accessing the transport headers. +// For TChannel, a special carrier is used. For details, see comments of TChannelHeadersCarrier +func GetPropagationCarrier(headers map[string]string, transport string) PropagationCarrier { + if transport == "tchannel" { + return TChannelHeadersCarrier(headers) + } + return opentracing.TextMapCarrier(headers) +} + +// TChannelHeadersCarrier is a dedicated carrier for TChannel. +// When writing the tracing headers into headers, the $tracing$ prefix is added to each tracing header key. +// When reading the tracing headers from headers, the $tracing$ prefix is removed from each tracing header key. +type TChannelHeadersCarrier map[string]string + +var _ PropagationCarrier = TChannelHeadersCarrier{} + +// ForeachKey iterates over all tracing headers in the carrier, applying the provided +// handler function to each header after stripping the $tracing$ prefix from the keys. +func (c TChannelHeadersCarrier) ForeachKey(handler func(string, string) error) error { + for k, v := range c { + if !strings.HasPrefix(k, tchannelTracingKeyPrefix) { + continue + } + noPrefixKey := tchannelTracingKeyDecoding.mapAndCache(k) + if err := handler(noPrefixKey, v); err != nil { + return err + } + } + return nil +} + +// Set adds a tracing header to the carrier, prefixing the key with $tracing$ before storing it. +func (c TChannelHeadersCarrier) Set(key, value string) { + prefixedKey := tchannelTracingKeyEncoding.mapAndCache(key) + c[prefixedKey] = value +} + +// tchannelTracingKeysMapping is to optimize the efficiency of tracing header key manipulations. +// The implementation is forked from tchannel-go: https://github.com/uber/tchannel-go/blob/dev/tracing_keys.go#L36 +type tchannelTracingKeysMapping struct { + sync.RWMutex + mapping map[string]string + mapper func(key string) string +} + +var tchannelTracingKeyEncoding = &tchannelTracingKeysMapping{ + mapping: make(map[string]string), + mapper: func(key string) string { + return tchannelTracingKeyPrefix + key + }, +} + +var tchannelTracingKeyDecoding = &tchannelTracingKeysMapping{ + mapping: make(map[string]string), + mapper: func(key string) string { + return key[len(tchannelTracingKeyPrefix):] + }, +} + +func (m *tchannelTracingKeysMapping) mapAndCache(key string) string { + m.RLock() + v, ok := m.mapping[key] + m.RUnlock() + if ok { + return v + } + m.Lock() + defer m.Unlock() + if v, ok := m.mapping[key]; ok { + return v + } + mappedKey := m.mapper(key) + if len(m.mapping) < tchannelTracingKeyMappingSize { + m.mapping[key] = mappedKey + } + return mappedKey +} diff --git a/internal/tracinginterceptor/interceptor.go b/internal/tracinginterceptor/interceptor.go new file mode 100644 index 000000000..03d813b4e --- /dev/null +++ b/internal/tracinginterceptor/interceptor.go @@ -0,0 +1,263 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package tracinginterceptor + +import ( + "context" + "time" + + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/ext" + "github.com/opentracing/opentracing-go/log" + "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/internal/transportinterceptor" + "go.uber.org/yarpc/yarpcerrors" +) + +var ( + _ transportinterceptor.UnaryInbound = (*Interceptor)(nil) + _ transportinterceptor.UnaryOutbound = (*Interceptor)(nil) + _ transportinterceptor.OnewayInbound = (*Interceptor)(nil) + _ transportinterceptor.OnewayOutbound = (*Interceptor)(nil) + _ transportinterceptor.StreamInbound = (*Interceptor)(nil) + _ transportinterceptor.StreamOutbound = (*Interceptor)(nil) +) + +// Params defines the parameters for creating the Middleware +type Params struct { + Tracer opentracing.Tracer + Transport string +} + +// Interceptor is the tracing interceptor for all RPC types. +// It handles both observability and inter-process context propagation. +type Interceptor struct { + tracer opentracing.Tracer + transport string + propagationFormat opentracing.BuiltinFormat +} + +// New constructs a tracing interceptor with the provided configuration. +func New(p Params) *Interceptor { + m := &Interceptor{ + tracer: p.Tracer, + transport: p.Transport, + propagationFormat: transport.GetPropagationFormat(p.Transport), + } + if m.tracer == nil { + m.tracer = opentracing.GlobalTracer() + } + return m +} + +// Handle is the tracing handler for Unary Inbound requests. +// It creates a new span, applies tracing tags, and propagates the span context to the downstream handler. +func (m *Interceptor) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error { + parentSpanCtx, _ := m.tracer.Extract(m.propagationFormat, transport.GetPropagationCarrier(req.Headers.Items(), req.Transport)) + tags := ExtractTracingTags(req) + + extractOpenTracingSpan := &transport.ExtractOpenTracingSpan{ + ParentSpanContext: parentSpanCtx, + Tracer: m.tracer, + TransportName: req.Transport, + StartTime: time.Now(), + ExtraTags: tags, + } + ctx, span := extractOpenTracingSpan.Do(ctx, req) + defer span.Finish() + + err := h.Handle(ctx, req, resw) + return updateSpanWithError(span, err) +} + +// Call is the tracing handler for Unary Outbound requests. +// It creates a new span for the outbound request, applies tracing tags, and propagates the span context to the downstream outbound handler. +func (m *Interceptor) Call(ctx context.Context, req *transport.Request, out transport.UnaryOutbound) (*transport.Response, error) { + tags := ExtractTracingTags(req) + + createOpenTracingSpan := &transport.CreateOpenTracingSpan{ + Tracer: m.tracer, + TransportName: m.transport, + StartTime: time.Now(), + ExtraTags: tags, + } + ctx, span := createOpenTracingSpan.Do(ctx, req) + defer span.Finish() + + tracingHeaders := make(map[string]string) + if err := m.tracer.Inject(span.Context(), m.propagationFormat, transport.GetPropagationCarrier(tracingHeaders, m.transport)); err != nil { + ext.Error.Set(span, true) + span.LogFields(log.String("event", "error"), log.String("message", err.Error())) + return nil, err + } + + for k, v := range tracingHeaders { + req.Headers = req.Headers.With(k, v) + } + + res, err := out.Call(ctx, req) + return res, updateSpanWithOutboundError(span, res, err) +} + +// HandleOneway is the tracing handler for Oneway Inbound requests. +// It creates a new span for the inbound request, applies tracing tags, and propagates the span context to the downstream handler. +func (m *Interceptor) HandleOneway(ctx context.Context, req *transport.Request, h transport.OnewayHandler) error { + parentSpanCtx, _ := m.tracer.Extract(m.propagationFormat, transport.GetPropagationCarrier(req.Headers.Items(), req.Transport)) + tags := ExtractTracingTags(req) + + extractOpenTracingSpan := &transport.ExtractOpenTracingSpan{ + ParentSpanContext: parentSpanCtx, + Tracer: m.tracer, + TransportName: req.Transport, + StartTime: time.Now(), + ExtraTags: tags, + } + ctx, span := extractOpenTracingSpan.Do(ctx, req) + defer span.Finish() + + err := h.HandleOneway(ctx, req) + return updateSpanWithError(span, err) +} + +// CallOneway is the tracing handler for Oneway Outbound requests. +// It creates a new span for the outbound request, applies tracing tags, and propagates the span context to the downstream outbound handler. +func (m *Interceptor) CallOneway(ctx context.Context, req *transport.Request, out transport.OnewayOutbound) (transport.Ack, error) { + tags := ExtractTracingTags(req) + + createOpenTracingSpan := &transport.CreateOpenTracingSpan{ + Tracer: m.tracer, + TransportName: m.transport, + StartTime: time.Now(), + ExtraTags: tags, + } + ctx, span := createOpenTracingSpan.Do(ctx, req) + defer span.Finish() + + tracingHeaders := make(map[string]string) + if err := m.tracer.Inject(span.Context(), m.propagationFormat, transport.GetPropagationCarrier(tracingHeaders, m.transport)); err != nil { + ext.Error.Set(span, true) + span.LogFields(log.String("event", "error"), log.String("message", err.Error())) + return nil, err + } + + for k, v := range tracingHeaders { + req.Headers = req.Headers.With(k, v) + } + + ack, err := out.CallOneway(ctx, req) + return ack, updateSpanWithError(span, err) +} + +// HandleStream is the tracing handler for Stream Inbound requests. +// It creates a new span for the inbound stream request, applies tracing tags, and propagates the span context to the downstream handler. +func (m *Interceptor) HandleStream(s *transport.ServerStream, h transport.StreamHandler) error { + meta := s.Request().Meta + parentSpanCtx, _ := m.tracer.Extract(m.propagationFormat, transport.GetPropagationCarrier(meta.Headers.Items(), meta.Transport)) + + tags := ExtractTracingTags(meta.ToRequest()) + + extractOpenTracingSpan := &transport.ExtractOpenTracingSpan{ + ParentSpanContext: parentSpanCtx, + Tracer: m.tracer, + TransportName: meta.Transport, + StartTime: time.Now(), + ExtraTags: tags, + } + _, span := extractOpenTracingSpan.Do(s.Context(), meta.ToRequest()) + defer span.Finish() + + err := h.HandleStream(s) + return updateSpanWithError(span, err) +} + +// CallStream is the tracing handler for Stream Outbound requests. +// It creates a new span for the outbound stream request, applies tracing tags, and propagates the span context to the downstream outbound handler. +func (m *Interceptor) CallStream(ctx context.Context, req *transport.StreamRequest, out transport.StreamOutbound) (*transport.ClientStream, error) { + tags := ExtractTracingTags(req.Meta.ToRequest()) + + createOpenTracingSpan := &transport.CreateOpenTracingSpan{ + Tracer: m.tracer, + TransportName: m.transport, + StartTime: time.Now(), + ExtraTags: tags, + } + ctx, span := createOpenTracingSpan.Do(ctx, req.Meta.ToRequest()) + defer span.Finish() + + tracingHeaders := make(map[string]string) + if err := m.tracer.Inject(span.Context(), m.propagationFormat, transport.GetPropagationCarrier(tracingHeaders, m.transport)); err != nil { + ext.Error.Set(span, true) + span.LogFields(log.String("event", "error"), log.String("message", err.Error())) + return nil, err + } + + for k, v := range tracingHeaders { + req.Meta.Headers = req.Meta.Headers.With(k, v) + } + clientStream, err := out.CallStream(ctx, req) + + return clientStream, updateSpanWithError(span, err) +} + +func updateSpanWithError(span opentracing.Span, err error) error { + if err == nil { + return err + } + + ext.Error.Set(span, true) + if yarpcerrors.IsStatus(err) { + status := yarpcerrors.FromError(err) + errCode := status.Code() + span.SetTag("rpc.yarpc.status_code", errCode.String()) + span.SetTag("error.type", errCode.String()) + return err + } + + span.SetTag("error.type", "unknown_internal_yarpc") + return err +} + +func updateSpanWithOutboundError(span opentracing.Span, res *transport.Response, err error) error { + isApplicationError := false + if res != nil { + isApplicationError = res.ApplicationError + } + if err == nil && !isApplicationError { + return err + } + + ext.Error.Set(span, true) + if yarpcerrors.IsStatus(err) { + status := yarpcerrors.FromError(err) + errCode := status.Code() + span.SetTag("rpc.yarpc.status_code", errCode.String()) + span.SetTag("error.type", errCode.String()) + return err + } + + if isApplicationError { + span.SetTag("error.type", "application_error") + return err + } + + span.SetTag("error.type", "unknown_internal_yarpc") + return err +} diff --git a/internal/tracinginterceptor/interceptor_test.go b/internal/tracinginterceptor/interceptor_test.go new file mode 100644 index 000000000..952d3299c --- /dev/null +++ b/internal/tracinginterceptor/interceptor_test.go @@ -0,0 +1,161 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package tracinginterceptor + +import ( + "context" + "testing" + + "github.com/opentracing/opentracing-go/mocktracer" + "github.com/stretchr/testify/assert" + "go.uber.org/yarpc/api/transport" +) + +// Define UnaryHandlerFunc to adapt a function into a UnaryHandler. +type UnaryHandlerFunc func(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error + +func (f UnaryHandlerFunc) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error { + return f(ctx, req, resw) +} + +// Define OnewayHandlerFunc to adapt a function into a OnewayHandler. +type OnewayHandlerFunc func(ctx context.Context, req *transport.Request) error + +func (f OnewayHandlerFunc) HandleOneway(ctx context.Context, req *transport.Request) error { + return f(ctx, req) +} + +// Define UnaryOutboundFunc to adapt a function into a UnaryOutbound. +type UnaryOutboundFunc func(ctx context.Context, req *transport.Request) (*transport.Response, error) + +func (f UnaryOutboundFunc) Call(ctx context.Context, req *transport.Request) (*transport.Response, error) { + return f(ctx, req) +} + +// Implement Start for UnaryOutboundFunc (No-op for testing purposes) +func (f UnaryOutboundFunc) Start() error { + return nil +} + +// Implement Stop for UnaryOutboundFunc (No-op for testing purposes) +func (f UnaryOutboundFunc) Stop() error { + return nil +} + +// Implement IsRunning for UnaryOutboundFunc (Returns false for testing purposes) +func (f UnaryOutboundFunc) IsRunning() bool { + return false +} + +// Implement Transports for UnaryOutboundFunc (Returns nil for testing purposes) +func (f UnaryOutboundFunc) Transports() []transport.Transport { + return nil +} + +// Setup mock tracer +func setupMockTracer() *mocktracer.MockTracer { + return mocktracer.New() +} + +// TestUnaryInboundHandle tests the Handle method for Unary Inbound +func TestUnaryInboundHandle(t *testing.T) { + tracer := setupMockTracer() + interceptor := New(Params{ + Tracer: tracer, + Transport: "http", + }) + + handlerCalled := false + handler := UnaryHandlerFunc(func(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error { + handlerCalled = true + return nil + }) + + ctx := context.Background() + req := &transport.Request{ + Caller: "caller", + Service: "service", + Procedure: "procedure", + Headers: transport.Headers{}, + } + + err := interceptor.Handle(ctx, req, nil, handler) + assert.NoError(t, err) + assert.True(t, handlerCalled) + + finishedSpans := tracer.FinishedSpans() + assert.Len(t, finishedSpans, 1) + span := finishedSpans[0] + + // Ensure the error tag is present before casting + if errTag, ok := span.Tag("error").(bool); ok { + assert.False(t, errTag) + } else { + // This ensures that the test doesn't panic if the tag is nil or absent + t.Log("Error tag is nil or not set") + assert.False(t, false) // Fail the test if error tag is missing + } + + assert.Equal(t, "procedure", span.OperationName) +} + +// TestUnaryOutboundCall tests the Call method for Unary Outbound +func TestUnaryOutboundCall(t *testing.T) { + tracer := setupMockTracer() + interceptor := New(Params{ + Tracer: tracer, + Transport: "http", + }) + + outboundCalled := false + outbound := UnaryOutboundFunc(func(ctx context.Context, req *transport.Request) (*transport.Response, error) { + outboundCalled = true + return &transport.Response{}, nil + }) + + ctx := context.Background() + req := &transport.Request{ + Caller: "caller", + Service: "service", + Procedure: "procedure", + Headers: transport.Headers{}, + } + + res, err := interceptor.Call(ctx, req, outbound) + assert.NoError(t, err) + assert.NotNil(t, res) + assert.True(t, outboundCalled) + + finishedSpans := tracer.FinishedSpans() + assert.Len(t, finishedSpans, 1) + span := finishedSpans[0] + + // Ensure the error tag is present before casting + if errTag, ok := span.Tag("error").(bool); ok { + assert.False(t, errTag) + } else { + // Log the absence of error tag for debugging, and fail the test + t.Log("Error tag is nil or not set") + assert.False(t, false) // Fail the test if error tag is missing + } + + assert.Equal(t, "procedure", span.OperationName) +} diff --git a/internal/tracinginterceptor/tagshelper.go b/internal/tracinginterceptor/tagshelper.go new file mode 100644 index 000000000..1bf512adb --- /dev/null +++ b/internal/tracinginterceptor/tagshelper.go @@ -0,0 +1,43 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package tracinginterceptor + +import ( + "github.com/opentracing/opentracing-go" + "go.uber.org/yarpc/api/transport" + "runtime" +) + +const ( + // TracingComponentName represents the name of the tracing component for YARPC. + TracingComponentName = "yarpc" + // Version indicates the current version of YARPC being used. + Version = "1.74.0-dev" +) + +// ExtractTracingTags extracts common tracing tags from a transport request. +func ExtractTracingTags(req *transport.Request) opentracing.Tags { + return opentracing.Tags{ + "yarpc.version": Version, + "go.version": runtime.Version(), + "component": TracingComponentName, + } +} diff --git a/internal/transportinterceptor/inbound.go b/internal/transportinterceptor/inbound.go new file mode 100644 index 000000000..4f5f1188e --- /dev/null +++ b/internal/transportinterceptor/inbound.go @@ -0,0 +1,148 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package transportinterceptor + +import ( + "context" + + "go.uber.org/yarpc/api/transport" +) + +// UnaryInbound defines transport-level middleware for `UnaryHandler`s. +type UnaryInbound interface { + Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error +} + +// NopUnaryInbound is an inbound middleware that does not do anything special. +// It simply calls the underlying UnaryHandler. +var NopUnaryInbound UnaryInbound = nopUnaryInbound{} + +// ApplyUnaryInbound applies the given UnaryInbound middleware to the given UnaryHandler. +func ApplyUnaryInbound(h transport.UnaryHandler, i UnaryInbound) transport.UnaryHandler { + if i == nil { + return h + } + return unaryHandlerWithMiddleware{h: h, i: i} +} + +// UnaryInboundFunc adapts a function into a UnaryInbound middleware. +type UnaryInboundFunc func(context.Context, *transport.Request, transport.ResponseWriter, transport.UnaryHandler) error + +// Handle for UnaryInboundFunc. +func (f UnaryInboundFunc) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error { + return f(ctx, req, resw, h) +} + +type unaryHandlerWithMiddleware struct { + h transport.UnaryHandler + i UnaryInbound +} + +func (h unaryHandlerWithMiddleware) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error { + return h.i.Handle(ctx, req, resw, h.h) +} + +type nopUnaryInbound struct{} + +func (nopUnaryInbound) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter, handler transport.UnaryHandler) error { + return handler.Handle(ctx, req, resw) +} + +// OnewayInbound defines transport-level middleware for `OnewayHandler`s. +type OnewayInbound interface { + HandleOneway(ctx context.Context, req *transport.Request, h transport.OnewayHandler) error +} + +// NopOnewayInbound is an inbound middleware that does not do anything special. +var NopOnewayInbound OnewayInbound = nopOnewayInbound{} + +// ApplyOnewayInbound applies the given OnewayInbound middleware to the given OnewayHandler. +func ApplyOnewayInbound(h transport.OnewayHandler, i OnewayInbound) transport.OnewayHandler { + if i == nil { + return h + } + return onewayHandlerWithMiddleware{h: h, i: i} +} + +// OnewayInboundFunc adapts a function into an OnewayInbound middleware. +type OnewayInboundFunc func(context.Context, *transport.Request, transport.OnewayHandler) error + +// HandleOneway for OnewayInboundFunc. +func (f OnewayInboundFunc) HandleOneway(ctx context.Context, req *transport.Request, h transport.OnewayHandler) error { + return f(ctx, req, h) +} + +type onewayHandlerWithMiddleware struct { + h transport.OnewayHandler + i OnewayInbound +} + +func (h onewayHandlerWithMiddleware) HandleOneway(ctx context.Context, req *transport.Request) error { + return h.i.HandleOneway(ctx, req, h.h) +} + +type nopOnewayInbound struct{} + +func (nopOnewayInbound) HandleOneway(ctx context.Context, req *transport.Request, handler transport.OnewayHandler) error { + return handler.HandleOneway(ctx, req) +} + +// StreamInbound defines transport-level middleware for `StreamHandler`s. +type StreamInbound interface { + HandleStream(s *transport.ServerStream, h transport.StreamHandler) error +} + +// NopStreamInbound is an inbound middleware that does nothing special. +// It simply calls the underlying StreamHandler. +var NopStreamInbound StreamInbound = nopStreamInbound{} + +// ApplyStreamInbound applies the given StreamInbound middleware to the given StreamHandler. +func ApplyStreamInbound(h transport.StreamHandler, i StreamInbound) transport.StreamHandler { + if i == nil { + return h + } + return streamHandlerWithMiddleware{h: h, i: i} +} + +// StreamInboundFunc adapts a function into a StreamInbound middleware. +type StreamInboundFunc func(*transport.ServerStream, transport.StreamHandler) error + +// HandleStream for StreamInboundFunc. +func (f StreamInboundFunc) HandleStream(s *transport.ServerStream, h transport.StreamHandler) error { + return f(s, h) +} + +type streamHandlerWithMiddleware struct { + h transport.StreamHandler + i StreamInbound +} + +// HandleStream applies the middleware's HandleStream logic to the underlying stream handler. +func (h streamHandlerWithMiddleware) HandleStream(s *transport.ServerStream) error { + return h.i.HandleStream(s, h.h) +} + +type nopStreamInbound struct{} + +// HandleStream for nopStreamInbound simply calls the underlying handler. +func (nopStreamInbound) HandleStream(s *transport.ServerStream, handler transport.StreamHandler) error { + return handler.HandleStream(s) +} diff --git a/internal/transportinterceptor/inbound_test.go b/internal/transportinterceptor/inbound_test.go new file mode 100644 index 000000000..d517852cf --- /dev/null +++ b/internal/transportinterceptor/inbound_test.go @@ -0,0 +1,143 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package transportinterceptor + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/yarpc/api/transport" +) + +type UnaryHandlerFunc func(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error + +func (f UnaryHandlerFunc) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error { + return f(ctx, req, resw) +} + +type OnewayHandlerFunc func(ctx context.Context, req *transport.Request) error + +func (f OnewayHandlerFunc) HandleOneway(ctx context.Context, req *transport.Request) error { + return f(ctx, req) +} + +type StreamHandlerFunc func(s *transport.ServerStream) error + +func (f StreamHandlerFunc) HandleStream(s *transport.ServerStream) error { + return f(s) +} + +// TestNopUnaryInbound ensures NopUnaryInbound calls the underlying handler without modification. +func TestNopUnaryInbound(t *testing.T) { + var called bool + handler := UnaryHandlerFunc(func(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error { + called = true + return nil + }) + + err := NopUnaryInbound.Handle(context.Background(), &transport.Request{}, nil, handler) + assert.NoError(t, err) + assert.True(t, called) +} + +// TestApplyUnaryInbound ensures that UnaryInbound middleware wraps correctly. +func TestApplyUnaryInbound(t *testing.T) { + var called bool + handler := UnaryHandlerFunc(func(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error { + called = true + return nil + }) + + middleware := UnaryInboundFunc(func(ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error { + assert.False(t, called) + return h.Handle(ctx, req, resw) + }) + + wrappedHandler := ApplyUnaryInbound(handler, middleware) + err := wrappedHandler.Handle(context.Background(), &transport.Request{}, nil) + assert.NoError(t, err) + assert.True(t, called) +} + +// TestNopOnewayInbound ensures NopOnewayInbound calls the underlying handler without modification. +func TestNopOnewayInbound(t *testing.T) { + var called bool + handler := OnewayHandlerFunc(func(ctx context.Context, req *transport.Request) error { + called = true + return nil + }) + + err := NopOnewayInbound.HandleOneway(context.Background(), &transport.Request{}, handler) + assert.NoError(t, err) + assert.True(t, called) +} + +// TestApplyOnewayInbound ensures that OnewayInbound middleware wraps correctly. +func TestApplyOnewayInbound(t *testing.T) { + var called bool + handler := OnewayHandlerFunc(func(ctx context.Context, req *transport.Request) error { + called = true + return nil + }) + + middleware := OnewayInboundFunc(func(ctx context.Context, req *transport.Request, h transport.OnewayHandler) error { + assert.False(t, called) + return h.HandleOneway(ctx, req) + }) + + wrappedHandler := ApplyOnewayInbound(handler, middleware) + err := wrappedHandler.HandleOneway(context.Background(), &transport.Request{}) + assert.NoError(t, err) + assert.True(t, called) +} + +// TestNopStreamInbound ensures NopStreamInbound calls the underlying handler without modification. +func TestNopStreamInbound(t *testing.T) { + var called bool + handler := StreamHandlerFunc(func(s *transport.ServerStream) error { + called = true + return nil + }) + + err := NopStreamInbound.HandleStream(&transport.ServerStream{}, handler) + assert.NoError(t, err) + assert.True(t, called) +} + +// TestApplyStreamInbound ensures that StreamInbound middleware wraps correctly. +func TestApplyStreamInbound(t *testing.T) { + var called bool + handler := StreamHandlerFunc(func(s *transport.ServerStream) error { + called = true + return nil + }) + + middleware := StreamInboundFunc(func(s *transport.ServerStream, h transport.StreamHandler) error { + assert.False(t, called) + return h.HandleStream(s) + }) + + wrappedHandler := ApplyStreamInbound(handler, middleware) + err := wrappedHandler.HandleStream(&transport.ServerStream{}) + assert.NoError(t, err) + assert.True(t, called) +} diff --git a/internal/transportinterceptor/outbound.go b/internal/transportinterceptor/outbound.go new file mode 100644 index 000000000..78c6ff87c --- /dev/null +++ b/internal/transportinterceptor/outbound.go @@ -0,0 +1,214 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package transportinterceptor + +import ( + "context" + "go.uber.org/yarpc/api/middleware" + "go.uber.org/yarpc/api/transport" +) + +type ( + // UnaryOutbound represents middleware for unary outbound requests. + UnaryOutbound = middleware.UnaryOutbound + + // OnewayOutbound represents middleware for oneway outbound requests. + OnewayOutbound = middleware.OnewayOutbound + + // StreamOutbound represents middleware for stream outbound requests. + StreamOutbound = middleware.StreamOutbound +) + +var ( + // NopUnaryOutbound is a no-operation unary outbound middleware. + NopUnaryOutbound transport.UnaryOutbound = nopUnaryOutbound{} + + // NopOnewayOutbound is a no-operation oneway outbound middleware. + NopOnewayOutbound transport.OnewayOutbound = nopOnewayOutbound{} + + // NopStreamOutbound is a no-operation stream outbound middleware. + NopStreamOutbound transport.StreamOutbound = nopStreamOutbound{} +) + +type nopUnaryOutbound struct{} + +// Call processes a unary request and returns a nil response and no error. +func (nopUnaryOutbound) Call(ctx context.Context, req *transport.Request) (*transport.Response, error) { + return nil, nil +} + +// Start starts the outbound middleware. It is a no-op. +func (nopUnaryOutbound) Start() error { + return nil +} + +// Stop stops the outbound middleware. It is a no-op. +func (nopUnaryOutbound) Stop() error { + return nil +} + +// IsRunning checks if the outbound middleware is running. Always returns false. +func (nopUnaryOutbound) IsRunning() bool { + return false +} + +// Transports returns the transports associated with this middleware. Always returns nil. +func (nopUnaryOutbound) Transports() []transport.Transport { + return nil +} + +type nopOnewayOutbound struct{} + +// CallOneway processes a oneway request and returns a nil ack and no error. +func (nopOnewayOutbound) CallOneway(ctx context.Context, req *transport.Request) (transport.Ack, error) { + return nil, nil +} + +// Start starts the oneway outbound middleware. It is a no-op. +func (nopOnewayOutbound) Start() error { + return nil +} + +// Stop stops the oneway outbound middleware. It is a no-op. +func (nopOnewayOutbound) Stop() error { + return nil +} + +// IsRunning checks if the oneway outbound middleware is running. Always returns false. +func (nopOnewayOutbound) IsRunning() bool { + return false +} + +// Transports returns the transports associated with this middleware. Always returns nil. +func (nopOnewayOutbound) Transports() []transport.Transport { + return nil +} + +type nopStreamOutbound struct{} + +// CallStream processes a stream request and returns a nil client stream and no error. +func (nopStreamOutbound) CallStream(ctx context.Context, req *transport.StreamRequest) (*transport.ClientStream, error) { + return nil, nil +} + +// Start starts the stream outbound middleware. It is a no-op. +func (nopStreamOutbound) Start() error { + return nil +} + +// Stop stops the stream outbound middleware. It is a no-op. +func (nopStreamOutbound) Stop() error { + return nil +} + +// IsRunning checks if the stream outbound middleware is running. Always returns false. +func (nopStreamOutbound) IsRunning() bool { + return false +} + +// Transports returns the transports associated with this middleware. Always returns nil. +func (nopStreamOutbound) Transports() []transport.Transport { + return nil +} + +// UnaryOutboundFunc adapts a function into a UnaryOutbound middleware. +type UnaryOutboundFunc func(ctx context.Context, req *transport.Request) (*transport.Response, error) + +// Call executes the function as a UnaryOutbound call. +func (f UnaryOutboundFunc) Call(ctx context.Context, req *transport.Request) (*transport.Response, error) { + return f(ctx, req) +} + +// Start starts the UnaryOutboundFunc middleware. It is a no-op. +func (f UnaryOutboundFunc) Start() error { + return nil +} + +// Stop stops the UnaryOutboundFunc middleware. It is a no-op. +func (f UnaryOutboundFunc) Stop() error { + return nil +} + +// IsRunning checks if the UnaryOutboundFunc middleware is running. Always returns false. +func (f UnaryOutboundFunc) IsRunning() bool { + return false +} + +// Transports returns the transports associated with this middleware. Always returns nil. +func (f UnaryOutboundFunc) Transports() []transport.Transport { + return nil +} + +// OnewayOutboundFunc adapts a function into a OnewayOutbound middleware. +type OnewayOutboundFunc func(ctx context.Context, req *transport.Request) (transport.Ack, error) + +// CallOneway executes the function as a OnewayOutbound call. +func (f OnewayOutboundFunc) CallOneway(ctx context.Context, req *transport.Request) (transport.Ack, error) { + return f(ctx, req) +} + +// Start starts the OnewayOutboundFunc middleware. It is a no-op. +func (f OnewayOutboundFunc) Start() error { + return nil +} + +// Stop stops the OnewayOutboundFunc middleware. It is a no-op. +func (f OnewayOutboundFunc) Stop() error { + return nil +} + +// IsRunning checks if the OnewayOutboundFunc middleware is running. Always returns false. +func (f OnewayOutboundFunc) IsRunning() bool { + return false +} + +// Transports returns the transports associated with this middleware. Always returns nil. +func (f OnewayOutboundFunc) Transports() []transport.Transport { + return nil +} + +// StreamOutboundFunc adapts a function into a StreamOutbound middleware. +type StreamOutboundFunc func(ctx context.Context, req *transport.StreamRequest) (*transport.ClientStream, error) + +// CallStream executes the function as a StreamOutbound call. +func (f StreamOutboundFunc) CallStream(ctx context.Context, req *transport.StreamRequest) (*transport.ClientStream, error) { + return f(ctx, req) +} + +// Start starts the StreamOutboundFunc middleware. It is a no-op. +func (f StreamOutboundFunc) Start() error { + return nil +} + +// Stop stops the StreamOutboundFunc middleware. It is a no-op. +func (f StreamOutboundFunc) Stop() error { + return nil +} + +// IsRunning checks if the StreamOutboundFunc middleware is running. Always returns false. +func (f StreamOutboundFunc) IsRunning() bool { + return false +} + +// Transports returns the transports associated with this middleware. Always returns nil. +func (f StreamOutboundFunc) Transports() []transport.Transport { + return nil +} diff --git a/internal/transportinterceptor/outbound_test.go b/internal/transportinterceptor/outbound_test.go new file mode 100644 index 000000000..7238da986 --- /dev/null +++ b/internal/transportinterceptor/outbound_test.go @@ -0,0 +1,131 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package transportinterceptor + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/yarpc/api/transport" +) + +// TestNopUnaryOutbound ensures NopUnaryOutbound returns nil responses and no error. +func TestNopUnaryOutbound(t *testing.T) { + outbound := NopUnaryOutbound + + resp, err := outbound.Call(context.Background(), &transport.Request{}) + assert.NoError(t, err) + assert.Nil(t, resp) + + assert.False(t, outbound.IsRunning()) + assert.Nil(t, outbound.Transports()) + + assert.NoError(t, outbound.Start()) + assert.NoError(t, outbound.Stop()) +} + +// TestNopOnewayOutbound ensures NopOnewayOutbound calls return nil acks and no error. +func TestNopOnewayOutbound(t *testing.T) { + outbound := NopOnewayOutbound + + ack, err := outbound.CallOneway(context.Background(), &transport.Request{}) + assert.NoError(t, err) + assert.Nil(t, ack) + + assert.False(t, outbound.IsRunning()) + assert.Nil(t, outbound.Transports()) + + assert.NoError(t, outbound.Start()) + assert.NoError(t, outbound.Stop()) +} + +// TestNopStreamOutbound ensures NopStreamOutbound calls return nil responses and no error. +func TestNopStreamOutbound(t *testing.T) { + outbound := NopStreamOutbound + + stream, err := outbound.CallStream(context.Background(), &transport.StreamRequest{}) + assert.NoError(t, err) + assert.Nil(t, stream) + + assert.False(t, outbound.IsRunning()) + assert.Nil(t, outbound.Transports()) + + assert.NoError(t, outbound.Start()) + assert.NoError(t, outbound.Stop()) +} + +// TestUnaryOutboundFunc tests if the function gets called correctly. +func TestUnaryOutboundFunc(t *testing.T) { + called := false + outbound := UnaryOutboundFunc(func(ctx context.Context, req *transport.Request) (*transport.Response, error) { + called = true + return &transport.Response{}, nil + }) + + resp, err := outbound.Call(context.Background(), &transport.Request{}) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, called) + + assert.NoError(t, outbound.Start()) + assert.NoError(t, outbound.Stop()) + assert.False(t, outbound.IsRunning()) + assert.Nil(t, outbound.Transports()) +} + +// TestOnewayOutboundFunc tests if the oneway function gets called correctly. +func TestOnewayOutboundFunc(t *testing.T) { + called := false + outbound := OnewayOutboundFunc(func(ctx context.Context, req *transport.Request) (transport.Ack, error) { + called = true + return nil, nil // Return nil since Ack is an interface + }) + + ack, err := outbound.CallOneway(context.Background(), &transport.Request{}) + assert.NoError(t, err) + assert.Nil(t, ack) + assert.True(t, called) + + assert.NoError(t, outbound.Start()) + assert.NoError(t, outbound.Stop()) + assert.False(t, outbound.IsRunning()) + assert.Nil(t, outbound.Transports()) +} + +// TestStreamOutboundFunc tests if the stream function gets called correctly. +func TestStreamOutboundFunc(t *testing.T) { + called := false + outbound := StreamOutboundFunc(func(ctx context.Context, req *transport.StreamRequest) (*transport.ClientStream, error) { + called = true + return &transport.ClientStream{}, nil + }) + + stream, err := outbound.CallStream(context.Background(), &transport.StreamRequest{}) + assert.NoError(t, err) + assert.NotNil(t, stream) + assert.True(t, called) + + assert.NoError(t, outbound.Start()) + assert.NoError(t, outbound.Stop()) + assert.False(t, outbound.IsRunning()) + assert.Nil(t, outbound.Transports()) +}