diff --git a/third-party/thrift/src/thrift/lib/cpp/ContextStack.cpp b/third-party/thrift/src/thrift/lib/cpp/ContextStack.cpp index 6cfbc099b2725..75900bdfbfe86 100644 --- a/third-party/thrift/src/thrift/lib/cpp/ContextStack.cpp +++ b/third-party/thrift/src/thrift/lib/cpp/ContextStack.cpp @@ -25,6 +25,22 @@ namespace apache { namespace thrift { +namespace { +const char* stripServiceNamePrefix( + const char* method, const char* serviceName) { + const char* unprefixed = method; + for (const char* c = serviceName; *c != '\0'; ++c) { + if (*unprefixed != *c) { + // The method name does not contain the service name as prefix + return method; + } + unprefixed++; + } + // Almost... missing the dot implies it does not count as a prefix + return *unprefixed == '.' ? unprefixed + 1 : method; +} +} // namespace + using util::AllocationColocator; class ContextStack::EmbeddedClientRequestContext @@ -45,14 +61,16 @@ ContextStack::ContextStack( TConnectionContext* connectionContext) : handlers_(handlers), serviceName_(serviceName), - method_(method), + methodNamePrefixed_(method), + methodNameUnprefixed_( + stripServiceNamePrefix(methodNamePrefixed_, serviceName_)), serviceContexts_(serviceContexts) { if (!handlers_ || handlers_->empty()) { return; } for (size_t i = 0; i < handlers_->size(); ++i) { contextAt(i) = (*handlers_)[i]->getServiceContext( - serviceName_, method_, connectionContext); + serviceName_, methodNamePrefixed_, connectionContext); } } @@ -81,7 +99,7 @@ ContextStack::ContextStack( ContextStack::~ContextStack() { if (handlers_) { for (size_t i = 0; i < handlers_->size(); i++) { - (*handlers_)[i]->freeContext(contextAt(i), method_); + (*handlers_)[i]->freeContext(contextAt(i), methodNamePrefixed_); } } } @@ -212,52 +230,69 @@ ContextStack::UniquePtr ContextStack::createWithClientContextCopyNames( } void ContextStack::preWrite() { - FOLLY_SDT(thrift, thrift_context_stack_pre_write, serviceName_, method_); + FOLLY_SDT( + thrift, + thrift_context_stack_pre_write, + serviceName_, + methodNamePrefixed_); if (handlers_) { for (size_t i = 0; i < handlers_->size(); i++) { - (*handlers_)[i]->preWrite(contextAt(i), method_); + (*handlers_)[i]->preWrite(contextAt(i), methodNamePrefixed_); } } } void ContextStack::onWriteData(const SerializedMessage& msg) { - FOLLY_SDT(thrift, thrift_context_stack_on_write_data, serviceName_, method_); + FOLLY_SDT( + thrift, + thrift_context_stack_on_write_data, + serviceName_, + methodNamePrefixed_); if (handlers_) { for (size_t i = 0; i < handlers_->size(); i++) { - (*handlers_)[i]->onWriteData(contextAt(i), method_, msg); + (*handlers_)[i]->onWriteData(contextAt(i), methodNamePrefixed_, msg); } } } void ContextStack::postWrite(uint32_t bytes) { FOLLY_SDT( - thrift, thrift_context_stack_post_write, serviceName_, method_, bytes); + thrift, + thrift_context_stack_post_write, + serviceName_, + methodNamePrefixed_, + bytes); if (handlers_) { for (size_t i = 0; i < handlers_->size(); i++) { - (*handlers_)[i]->postWrite(contextAt(i), method_, bytes); + (*handlers_)[i]->postWrite(contextAt(i), methodNamePrefixed_, bytes); } } } void ContextStack::preRead() { - FOLLY_SDT(thrift, thrift_context_stack_pre_read, serviceName_, method_); + FOLLY_SDT( + thrift, thrift_context_stack_pre_read, serviceName_, methodNamePrefixed_); if (handlers_) { for (size_t i = 0; i < handlers_->size(); i++) { - (*handlers_)[i]->preRead(contextAt(i), method_); + (*handlers_)[i]->preRead(contextAt(i), methodNamePrefixed_); } } } void ContextStack::onReadData(const SerializedMessage& msg) { - FOLLY_SDT(thrift, thrift_context_stack_on_read_data, serviceName_, method_); + FOLLY_SDT( + thrift, + thrift_context_stack_on_read_data, + serviceName_, + methodNamePrefixed_); if (handlers_) { for (size_t i = 0; i < handlers_->size(); i++) { - (*handlers_)[i]->onReadData(contextAt(i), method_, msg); + (*handlers_)[i]->onReadData(contextAt(i), methodNamePrefixed_, msg); } } } @@ -265,11 +300,16 @@ void ContextStack::onReadData(const SerializedMessage& msg) { void ContextStack::postRead( apache::thrift::transport::THeader* header, uint32_t bytes) { FOLLY_SDT( - thrift, thrift_context_stack_post_read, serviceName_, method_, bytes); + thrift, + thrift_context_stack_post_read, + serviceName_, + methodNamePrefixed_, + bytes); if (handlers_) { for (size_t i = 0; i < handlers_->size(); i++) { - (*handlers_)[i]->postRead(contextAt(i), method_, header, bytes); + (*handlers_)[i]->postRead( + contextAt(i), methodNamePrefixed_, header, bytes); } } } @@ -290,11 +330,12 @@ void ContextStack::handlerErrorWrapped(const folly::exception_wrapper& ew) { thrift, thrift_context_stack_handler_error_wrapped, serviceName_, - method_); + methodNamePrefixed_); if (handlers_) { for (size_t i = 0; i < handlers_->size(); i++) { - (*handlers_)[i]->handlerErrorWrapped(contextAt(i), method_, ew); + (*handlers_)[i]->handlerErrorWrapped( + contextAt(i), methodNamePrefixed_, ew); } } } @@ -305,12 +346,12 @@ void ContextStack::userExceptionWrapped( thrift, thrift_context_stack_user_exception_wrapped, serviceName_, - method_); + methodNamePrefixed_); if (handlers_) { for (size_t i = 0; i < handlers_->size(); i++) { (*handlers_)[i]->userExceptionWrapped( - contextAt(i), method_, declared, ew); + contextAt(i), methodNamePrefixed_, declared, ew); } } } @@ -332,7 +373,9 @@ folly::Try ContextStack::processClientInterceptorsOnRequest() noexcept { for (std::size_t i = 0; i < clientInterceptors_->size(); ++i) { const auto& clientInterceptor = (*clientInterceptors_)[i]; ClientInterceptorBase::RequestInfo requestInfo{ - getStorageForClientInterceptorOnRequestByIndex(i)}; + getStorageForClientInterceptorOnRequestByIndex(i), + serviceName_, + methodNameUnprefixed_}; try { clientInterceptor->internal_onRequest(std::move(requestInfo)); } catch (...) { @@ -359,7 +402,9 @@ folly::Try ContextStack::processClientInterceptorsOnResponse() noexcept { for (auto i = std::ptrdiff_t(clientInterceptors_->size()) - 1; i >= 0; --i) { const auto& clientInterceptor = (*clientInterceptors_)[i]; ClientInterceptorBase::ResponseInfo responseInfo{ - getStorageForClientInterceptorOnRequestByIndex(i)}; + getStorageForClientInterceptorOnRequestByIndex(i), + serviceName_, + methodNameUnprefixed_}; try { clientInterceptor->internal_onResponse(std::move(responseInfo)); } catch (...) { diff --git a/third-party/thrift/src/thrift/lib/cpp/ContextStack.h b/third-party/thrift/src/thrift/lib/cpp/ContextStack.h index 5ffc372186c13..e9255aae0168c 100644 --- a/third-party/thrift/src/thrift/lib/cpp/ContextStack.h +++ b/third-party/thrift/src/thrift/lib/cpp/ContextStack.h @@ -113,8 +113,12 @@ class ContextStack { handlers_; std::shared_ptr>> clientInterceptors_; + // Must be NUL-terminated. const char* const serviceName_; - const char* const method_; + // "{service_name}.{method_name}" + const char* const methodNamePrefixed_; + // "{method_name}", without the service name prefix + const char* const methodNameUnprefixed_; void** serviceContexts_; // While the server-side has a Cpp2RequestContext, the client-side "fakes" it // with an embedded version. We can't make it nullptr because this is the API diff --git a/third-party/thrift/src/thrift/lib/cpp2/async/ClientInterceptorBase.h b/third-party/thrift/src/thrift/lib/cpp2/async/ClientInterceptorBase.h index 5106716d3335c..469a551e6faed 100644 --- a/third-party/thrift/src/thrift/lib/cpp2/async/ClientInterceptorBase.h +++ b/third-party/thrift/src/thrift/lib/cpp2/async/ClientInterceptorBase.h @@ -33,11 +33,31 @@ class ClientInterceptorBase { struct RequestInfo { detail::ClientInterceptorOnRequestStorage* storage = nullptr; + /** + * The name of the service definition as specified in Thrift IDL. + */ + const char* serviceName = nullptr; + /** + * The name of the method as specified in Thrift IDL. This does NOT include + * the service name. If the method is an interaction method, then it will be + * in the format `{interaction_name}.{method_name}`. + */ + const char* methodName = nullptr; }; virtual void internal_onRequest(RequestInfo) = 0; struct ResponseInfo { detail::ClientInterceptorOnRequestStorage* storage = nullptr; + /** + * The name of the service definition as specified in Thrift IDL. + */ + const char* serviceName = nullptr; + /** + * The name of the method as specified in Thrift IDL. This does NOT include + * the service name. If the method is an interaction method, then it will be + * in the format `{interaction_name}.{method_name}`. + */ + const char* methodName = nullptr; }; virtual void internal_onResponse(ResponseInfo) = 0; }; diff --git a/third-party/thrift/src/thrift/lib/cpp2/async/tests/ClientInterceptorTest.cpp b/third-party/thrift/src/thrift/lib/cpp2/async/tests/ClientInterceptorTest.cpp index 41d891b176cce..2e24f180bf53e 100644 --- a/third-party/thrift/src/thrift/lib/cpp2/async/tests/ClientInterceptorTest.cpp +++ b/third-party/thrift/src/thrift/lib/cpp2/async/tests/ClientInterceptorTest.cpp @@ -339,12 +339,39 @@ class ClientInterceptorThatThrowsOnResponse } }; +class TracingClientInterceptor : public NamedClientInterceptor { + public: + using NamedClientInterceptor::NamedClientInterceptor; + + using Trace = std::pair; + const std::vector& requests() const { return requests_; } + const std::vector& responses() const { return responses_; } + + std::optional onRequest(RequestInfo requestInfo) override { + requests_.push_back( + {std::string(requestInfo.serviceName), + std::string(requestInfo.methodName)}); + return folly::unit; + } + + void onResponse(folly::Unit*, ResponseInfo responseInfo) override { + responses_.push_back( + {std::string(responseInfo.serviceName), + std::string(responseInfo.methodName)}); + } + + private: + std::vector requests_; + std::vector responses_; +}; + } // namespace CO_TEST_P(ClientInterceptorTestP, Basic) { auto interceptor = std::make_shared("Interceptor1"); - auto client = makeClient(makeInterceptorsList(interceptor)); + auto tracer = std::make_shared("Tracer"); + auto client = makeClient(makeInterceptorsList(interceptor, tracer)); co_await client->echo("foo"); EXPECT_EQ(interceptor->onRequestCount, 1); @@ -353,6 +380,14 @@ CO_TEST_P(ClientInterceptorTestP, Basic) { co_await client->noop(); EXPECT_EQ(interceptor->onRequestCount, 2); EXPECT_EQ(interceptor->onResponseCount, 2); + + using Trace = TracingClientInterceptor::Trace; + const std::vector expectedTrace{ + Trace{"ClientInterceptorTest", "echo"}, + Trace{"ClientInterceptorTest", "noop"}, + }; + EXPECT_THAT(tracer->requests(), ElementsAreArray(expectedTrace)); + EXPECT_THAT(tracer->responses(), ElementsAreArray(expectedTrace)); } CO_TEST_P(ClientInterceptorTestP, OnRequestException) { @@ -362,8 +397,9 @@ CO_TEST_P(ClientInterceptorTestP, OnRequestException) { std::make_shared("Interceptor2"); auto interceptor3 = std::make_shared("Interceptor3"); + auto tracer = std::make_shared("Tracer"); auto client = makeClient( - makeInterceptorsList(interceptor1, interceptor2, interceptor3)); + makeInterceptorsList(interceptor1, interceptor2, interceptor3, tracer)); EXPECT_THROW( { @@ -387,6 +423,11 @@ CO_TEST_P(ClientInterceptorTestP, OnRequestException) { EXPECT_EQ(interceptor1->onResponseCount, 0); EXPECT_EQ(interceptor2->onResponseCount, 0); EXPECT_EQ(interceptor3->onResponseCount, 0); + + using Trace = TracingClientInterceptor::Trace; + EXPECT_THAT( + tracer->requests(), ElementsAre(Trace{"ClientInterceptorTest", "noop"})); + EXPECT_THAT(tracer->responses(), IsEmpty()); } CO_TEST_P(ClientInterceptorTestP, IterationOrder) { @@ -556,7 +597,9 @@ CO_TEST_P(ClientInterceptorTestP, BasicInteraction) { auto interceptor2 = std::make_shared( "Interceptor2"); - auto client = makeClient(makeInterceptorsList(interceptor1, interceptor2)); + auto tracer = std::make_shared("Tracer"); + auto client = + makeClient(makeInterceptorsList(interceptor1, interceptor2, tracer)); { auto interaction = co_await client->createInteraction(); @@ -582,6 +625,15 @@ CO_TEST_P(ClientInterceptorTestP, BasicInteraction) { EXPECT_EQ(interceptor->onRequestCount, 3); EXPECT_EQ(interceptor->onResponseCount, 3); } + + using Trace = TracingClientInterceptor::Trace; + const std::vector expectedTrace{ + Trace{"ClientInterceptorTest", "createInteraction"}, + Trace{"ClientInterceptorTest", "SampleInteraction.echo"}, + Trace{"ClientInterceptorTest", "echo"}, + }; + EXPECT_THAT(tracer->requests(), ElementsAreArray(expectedTrace)); + EXPECT_THAT(tracer->responses(), ElementsAreArray(expectedTrace)); } // With initial response @@ -592,7 +644,9 @@ CO_TEST_P(ClientInterceptorTestP, BasicInteraction) { auto interceptor2 = std::make_shared( "Interceptor2"); - auto client = makeClient(makeInterceptorsList(interceptor1, interceptor2)); + auto tracer = std::make_shared("Tracer"); + auto client = + makeClient(makeInterceptorsList(interceptor1, interceptor2, tracer)); { auto [interaction, response] = co_await client->createInteractionAndEcho("hello"); @@ -618,6 +672,15 @@ CO_TEST_P(ClientInterceptorTestP, BasicInteraction) { EXPECT_EQ(interceptor->onRequestCount, 3); EXPECT_EQ(interceptor->onResponseCount, 3); } + + using Trace = TracingClientInterceptor::Trace; + const std::vector expectedTrace{ + Trace{"ClientInterceptorTest", "createInteractionAndEcho"}, + Trace{"ClientInterceptorTest", "SampleInteraction.echo"}, + Trace{"ClientInterceptorTest", "echo"}, + }; + EXPECT_THAT(tracer->requests(), ElementsAreArray(expectedTrace)); + EXPECT_THAT(tracer->responses(), ElementsAreArray(expectedTrace)); } }