diff --git a/transport/grpc/client.go b/transport/grpc/client.go index edc57d3efef..5182b180dc4 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -11,6 +11,7 @@ import ( grpcinsecure "google.golang.org/grpc/credentials/insecure" grpcmd "google.golang.org/grpc/metadata" + "github.com/go-kratos/kratos/v2/internal/matcher" "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/registry" @@ -132,6 +133,7 @@ type clientOptions struct { timeout time.Duration discovery registry.Discovery middleware []middleware.Middleware + streamMiddleware []middleware.Middleware ints []grpc.UnaryClientInterceptor streamInts []grpc.StreamClientInterceptor grpcOpts []grpc.DialOption @@ -166,7 +168,7 @@ func dial(ctx context.Context, insecure bool, opts ...ClientOption) (*grpc.Clien unaryClientInterceptor(options.middleware, options.timeout, options.filters), } sints := []grpc.StreamClientInterceptor{ - streamClientInterceptor(options.filters), + streamClientInterceptor(options.streamMiddleware, options.filters), } if len(options.ints) > 0 { @@ -239,7 +241,54 @@ func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration, f } } -func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInterceptor { +// wrappedClientStream wraps the grpc.ClientStream and applies middleware +type wrappedClientStream struct { + grpc.ClientStream + ctx context.Context + middleware matcher.Matcher +} + +func (w *wrappedClientStream) Context() context.Context { + return w.ctx +} + +func (w *wrappedClientStream) SendMsg(m interface{}) error { + h := func(ctx context.Context, req interface{}) (interface{}, error) { + return req, w.ClientStream.SendMsg(m) + } + + info, ok := transport.FromClientContext(w.ctx) + if !ok { + return fmt.Errorf("transport value stored in ctx returns: %v", ok) + } + + if next := w.middleware.Match(info.Operation()); len(next) > 0 { + h = middleware.Chain(next...)(h) + } + + _, err := h(w.ctx, m) + return err +} + +func (w *wrappedClientStream) RecvMsg(m interface{}) error { + h := func(ctx context.Context, req interface{}) (interface{}, error) { + return req, w.ClientStream.RecvMsg(m) + } + + info, ok := transport.FromClientContext(w.ctx) + if !ok { + return fmt.Errorf("transport value stored in ctx returns: %v", ok) + } + + if next := w.middleware.Match(info.Operation()); len(next) > 0 { + h = middleware.Chain(next...)(h) + } + + _, err := h(w.ctx, m) + return err +} + +func streamClientInterceptor(ms []middleware.Middleware, filters []selector.NodeFilter) grpc.StreamClientInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { // nolint ctx = transport.NewClientContext(ctx, &Transport{ endpoint: cc.Target(), @@ -249,6 +298,28 @@ func streamClientInterceptor(filters []selector.NodeFilter) grpc.StreamClientInt }) var p selector.Peer ctx = selector.NewPeerContext(ctx, &p) - return streamer(ctx, desc, cc, method, opts...) + + clientStream, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, err + } + + h := func(ctx context.Context, req interface{}) (interface{}, error) { + return streamer, nil + } + + m := matcher.New() + if len(ms) > 0 { + m.Use(ms...) + middleware.Chain(ms...)(h) + } + + wrappedStream := &wrappedClientStream{ + ClientStream: clientStream, + ctx: ctx, + middleware: m, + } + + return wrappedStream, nil } } diff --git a/transport/grpc/interceptor.go b/transport/grpc/interceptor.go index 6cc331547c6..6261442d6a3 100644 --- a/transport/grpc/interceptor.go +++ b/transport/grpc/interceptor.go @@ -2,11 +2,13 @@ package grpc import ( "context" + "fmt" "google.golang.org/grpc" grpcmd "google.golang.org/grpc/metadata" ic "github.com/go-kratos/kratos/v2/internal/context" + "github.com/go-kratos/kratos/v2/internal/matcher" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" ) @@ -48,13 +50,15 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { // wrappedStream is rewrite grpc stream's context type wrappedStream struct { grpc.ServerStream - ctx context.Context + ctx context.Context + middleware matcher.Matcher } -func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream { +func NewWrappedStream(ctx context.Context, stream grpc.ServerStream, m matcher.Matcher) grpc.ServerStream { return &wrappedStream{ ServerStream: stream, ctx: ctx, + middleware: m, } } @@ -76,7 +80,19 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor { replyHeader: headerCarrier(replyHeader), }) - ws := NewWrappedStream(ctx, ss) + h := func(ctx context.Context, req interface{}) (interface{}, error) { + return handler(srv, ss), nil + } + + if next := s.streamMiddleware.Match(info.FullMethod); len(next) > 0 { + middleware.Chain(next...)(h) + } + + ctx = context.WithValue(ctx, stream{ + ServerStream: ss, + streamMiddleware: s.streamMiddleware, + }, ss) + ws := NewWrappedStream(ctx, ss, s.streamMiddleware) err := handler(srv, ws) if len(replyHeader) > 0 { @@ -85,3 +101,48 @@ func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor { return err } } + +type stream struct { + grpc.ServerStream + streamMiddleware matcher.Matcher +} + +func GetStream(ctx context.Context) grpc.ServerStream { + return ctx.Value(stream{}).(grpc.ServerStream) +} + +func (w *wrappedStream) SendMsg(m interface{}) error { + h := func(_ context.Context, req interface{}) (interface{}, error) { + return req, w.ServerStream.SendMsg(m) + } + + info, ok := transport.FromServerContext(w.ctx) + if !ok { + return fmt.Errorf("transport value stored in ctx returns: %v", ok) + } + + if next := w.middleware.Match(info.Operation()); len(next) > 0 { + h = middleware.Chain(next...)(h) + } + + _, err := h(w.ctx, m) + return err +} + +func (w *wrappedStream) RecvMsg(m interface{}) error { + h := func(_ context.Context, req interface{}) (interface{}, error) { + return req, w.ServerStream.RecvMsg(m) + } + + info, ok := transport.FromServerContext(w.ctx) + if !ok { + return fmt.Errorf("transport value stored in ctx returns: %v", ok) + } + + if next := w.middleware.Match(info.Operation()); len(next) > 0 { + h = middleware.Chain(next...)(h) + } + + _, err := h(w.ctx, m) + return err +} diff --git a/transport/grpc/server.go b/transport/grpc/server.go index c8c3f4f3bcc..a1d74ace532 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -72,6 +72,12 @@ func Middleware(m ...middleware.Middleware) ServerOption { } } +func StreamMiddleware(m ...middleware.Middleware) ServerOption { + return func(s *Server) { + s.streamMiddleware.Use(m...) + } +} + // CustomHealth Checks server. func CustomHealth() ServerOption { return func(s *Server) { @@ -117,33 +123,35 @@ func Options(opts ...grpc.ServerOption) ServerOption { // Server is a gRPC server wrapper. type Server struct { *grpc.Server - baseCtx context.Context - tlsConf *tls.Config - lis net.Listener - err error - network string - address string - endpoint *url.URL - timeout time.Duration - middleware matcher.Matcher - unaryInts []grpc.UnaryServerInterceptor - streamInts []grpc.StreamServerInterceptor - grpcOpts []grpc.ServerOption - health *health.Server - customHealth bool - metadata *apimd.Server - adminClean func() + baseCtx context.Context + tlsConf *tls.Config + lis net.Listener + err error + network string + address string + endpoint *url.URL + timeout time.Duration + middleware matcher.Matcher + streamMiddleware matcher.Matcher + unaryInts []grpc.UnaryServerInterceptor + streamInts []grpc.StreamServerInterceptor + grpcOpts []grpc.ServerOption + health *health.Server + customHealth bool + metadata *apimd.Server + adminClean func() } // NewServer creates a gRPC server by options. func NewServer(opts ...ServerOption) *Server { srv := &Server{ - baseCtx: context.Background(), - network: "tcp", - address: ":0", - timeout: 1 * time.Second, - health: health.NewServer(), - middleware: matcher.New(), + baseCtx: context.Background(), + network: "tcp", + address: ":0", + timeout: 1 * time.Second, + health: health.NewServer(), + middleware: matcher.New(), + streamMiddleware: matcher.New(), } for _, o := range opts { o(srv) diff --git a/transport/grpc/server_test.go b/transport/grpc/server_test.go index d9d2c7af376..067c48596e1 100644 --- a/transport/grpc/server_test.go +++ b/transport/grpc/server_test.go @@ -12,6 +12,7 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/internal/matcher" @@ -280,6 +281,82 @@ func TestServer_unaryServerInterceptor(t *testing.T) { } } +type mockServerStream struct { + ctx context.Context + sentMsg interface{} + recvMsg interface{} + metadata metadata.MD + grpc.ServerStream +} + +func (m *mockServerStream) SetHeader(md metadata.MD) error { + m.metadata = md + return nil +} + +func (m *mockServerStream) SendHeader(md metadata.MD) error { + m.metadata = md + return nil +} + +func (m *mockServerStream) SetTrailer(md metadata.MD) { + m.metadata = md +} + +func (m *mockServerStream) Context() context.Context { + return m.ctx +} + +func (m *mockServerStream) SendMsg(msg interface{}) error { + m.sentMsg = msg + return nil +} + +func (m *mockServerStream) RecvMsg(msg interface{}) error { + m.recvMsg = msg + return nil +} + +func TestServer_streamServerInterceptor(t *testing.T) { + u, err := url.Parse("grpc://hello/world") + if err != nil { + t.Errorf("expect %v, got %v", nil, err) + } + srv := &Server{ + baseCtx: context.Background(), + endpoint: u, + timeout: time.Duration(10), + middleware: matcher.New(), + streamMiddleware: matcher.New(), + } + + srv.streamMiddleware.Use(EmptyMiddleware()) + + mockStream := &mockServerStream{ + ctx: srv.baseCtx, + } + + handler := func(_ interface{}, stream grpc.ServerStream) error { + resp := &testResp{Data: "stream hi"} + return stream.SendMsg(resp) + } + + info := &grpc.StreamServerInfo{ + FullMethod: "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", + } + + err = srv.streamServerInterceptor()(nil, mockStream, info, handler) + if err != nil { + t.Errorf("expect %v, got %v", nil, err) + } + + // Check response + resp := mockStream.sentMsg.(*testResp) + if !reflect.DeepEqual("stream hi", resp.Data) { + t.Errorf("expect %s, got %s", "stream hi", resp.Data) + } +} + func TestListener(t *testing.T) { lis, err := net.Listen("tcp", ":0") if err != nil {