Skip to content

Commit

Permalink
Pass service and method names to ClientInterceptor
Browse files Browse the repository at this point in the history
Summary: This matches `ServiceInterceptor`. `TProcessorEventHandler` passes in the service name as a prefix to method name, which is surprising. We make the decision to match `ServiceInterceptor` here for consistency.

Reviewed By: sazonovkirill

Differential Revision: D62663049

fbshipit-source-id: 686698e05a53a48b7a73d14b3585187d4d957689
  • Loading branch information
praihan authored and facebook-github-bot committed Sep 19, 2024
1 parent a0cb0a4 commit ff71d94
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 26 deletions.
87 changes: 66 additions & 21 deletions third-party/thrift/src/thrift/lib/cpp/ContextStack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@
namespace apache {
namespace thrift {

namespace {
const char* stripServiceNamePrefix(
const char* method, const char* serviceName) {
const char* unprefixed = method;
for (const char* c = serviceName; *c != '\0'; ++c) {
if (*unprefixed != *c) {
// The method name does not contain the service name as prefix
return method;
}
unprefixed++;
}
// Almost... missing the dot implies it does not count as a prefix
return *unprefixed == '.' ? unprefixed + 1 : method;
}
} // namespace

using util::AllocationColocator;

class ContextStack::EmbeddedClientRequestContext
Expand All @@ -45,14 +61,16 @@ ContextStack::ContextStack(
TConnectionContext* connectionContext)
: handlers_(handlers),
serviceName_(serviceName),
method_(method),
methodNamePrefixed_(method),
methodNameUnprefixed_(
stripServiceNamePrefix(methodNamePrefixed_, serviceName_)),
serviceContexts_(serviceContexts) {
if (!handlers_ || handlers_->empty()) {
return;
}
for (size_t i = 0; i < handlers_->size(); ++i) {
contextAt(i) = (*handlers_)[i]->getServiceContext(
serviceName_, method_, connectionContext);
serviceName_, methodNamePrefixed_, connectionContext);
}
}

Expand Down Expand Up @@ -81,7 +99,7 @@ ContextStack::ContextStack(
ContextStack::~ContextStack() {
if (handlers_) {
for (size_t i = 0; i < handlers_->size(); i++) {
(*handlers_)[i]->freeContext(contextAt(i), method_);
(*handlers_)[i]->freeContext(contextAt(i), methodNamePrefixed_);
}
}
}
Expand Down Expand Up @@ -212,64 +230,86 @@ ContextStack::UniquePtr ContextStack::createWithClientContextCopyNames(
}

void ContextStack::preWrite() {
FOLLY_SDT(thrift, thrift_context_stack_pre_write, serviceName_, method_);
FOLLY_SDT(
thrift,
thrift_context_stack_pre_write,
serviceName_,
methodNamePrefixed_);

if (handlers_) {
for (size_t i = 0; i < handlers_->size(); i++) {
(*handlers_)[i]->preWrite(contextAt(i), method_);
(*handlers_)[i]->preWrite(contextAt(i), methodNamePrefixed_);
}
}
}

void ContextStack::onWriteData(const SerializedMessage& msg) {
FOLLY_SDT(thrift, thrift_context_stack_on_write_data, serviceName_, method_);
FOLLY_SDT(
thrift,
thrift_context_stack_on_write_data,
serviceName_,
methodNamePrefixed_);

if (handlers_) {
for (size_t i = 0; i < handlers_->size(); i++) {
(*handlers_)[i]->onWriteData(contextAt(i), method_, msg);
(*handlers_)[i]->onWriteData(contextAt(i), methodNamePrefixed_, msg);
}
}
}

void ContextStack::postWrite(uint32_t bytes) {
FOLLY_SDT(
thrift, thrift_context_stack_post_write, serviceName_, method_, bytes);
thrift,
thrift_context_stack_post_write,
serviceName_,
methodNamePrefixed_,
bytes);

if (handlers_) {
for (size_t i = 0; i < handlers_->size(); i++) {
(*handlers_)[i]->postWrite(contextAt(i), method_, bytes);
(*handlers_)[i]->postWrite(contextAt(i), methodNamePrefixed_, bytes);
}
}
}

void ContextStack::preRead() {
FOLLY_SDT(thrift, thrift_context_stack_pre_read, serviceName_, method_);
FOLLY_SDT(
thrift, thrift_context_stack_pre_read, serviceName_, methodNamePrefixed_);

if (handlers_) {
for (size_t i = 0; i < handlers_->size(); i++) {
(*handlers_)[i]->preRead(contextAt(i), method_);
(*handlers_)[i]->preRead(contextAt(i), methodNamePrefixed_);
}
}
}

void ContextStack::onReadData(const SerializedMessage& msg) {
FOLLY_SDT(thrift, thrift_context_stack_on_read_data, serviceName_, method_);
FOLLY_SDT(
thrift,
thrift_context_stack_on_read_data,
serviceName_,
methodNamePrefixed_);

if (handlers_) {
for (size_t i = 0; i < handlers_->size(); i++) {
(*handlers_)[i]->onReadData(contextAt(i), method_, msg);
(*handlers_)[i]->onReadData(contextAt(i), methodNamePrefixed_, msg);
}
}
}

void ContextStack::postRead(
apache::thrift::transport::THeader* header, uint32_t bytes) {
FOLLY_SDT(
thrift, thrift_context_stack_post_read, serviceName_, method_, bytes);
thrift,
thrift_context_stack_post_read,
serviceName_,
methodNamePrefixed_,
bytes);

if (handlers_) {
for (size_t i = 0; i < handlers_->size(); i++) {
(*handlers_)[i]->postRead(contextAt(i), method_, header, bytes);
(*handlers_)[i]->postRead(
contextAt(i), methodNamePrefixed_, header, bytes);
}
}
}
Expand All @@ -290,11 +330,12 @@ void ContextStack::handlerErrorWrapped(const folly::exception_wrapper& ew) {
thrift,
thrift_context_stack_handler_error_wrapped,
serviceName_,
method_);
methodNamePrefixed_);

if (handlers_) {
for (size_t i = 0; i < handlers_->size(); i++) {
(*handlers_)[i]->handlerErrorWrapped(contextAt(i), method_, ew);
(*handlers_)[i]->handlerErrorWrapped(
contextAt(i), methodNamePrefixed_, ew);
}
}
}
Expand All @@ -305,12 +346,12 @@ void ContextStack::userExceptionWrapped(
thrift,
thrift_context_stack_user_exception_wrapped,
serviceName_,
method_);
methodNamePrefixed_);

if (handlers_) {
for (size_t i = 0; i < handlers_->size(); i++) {
(*handlers_)[i]->userExceptionWrapped(
contextAt(i), method_, declared, ew);
contextAt(i), methodNamePrefixed_, declared, ew);
}
}
}
Expand All @@ -332,7 +373,9 @@ folly::Try<void> ContextStack::processClientInterceptorsOnRequest() noexcept {
for (std::size_t i = 0; i < clientInterceptors_->size(); ++i) {
const auto& clientInterceptor = (*clientInterceptors_)[i];
ClientInterceptorBase::RequestInfo requestInfo{
getStorageForClientInterceptorOnRequestByIndex(i)};
getStorageForClientInterceptorOnRequestByIndex(i),
serviceName_,
methodNameUnprefixed_};
try {
clientInterceptor->internal_onRequest(std::move(requestInfo));
} catch (...) {
Expand All @@ -359,7 +402,9 @@ folly::Try<void> ContextStack::processClientInterceptorsOnResponse() noexcept {
for (auto i = std::ptrdiff_t(clientInterceptors_->size()) - 1; i >= 0; --i) {
const auto& clientInterceptor = (*clientInterceptors_)[i];
ClientInterceptorBase::ResponseInfo responseInfo{
getStorageForClientInterceptorOnRequestByIndex(i)};
getStorageForClientInterceptorOnRequestByIndex(i),
serviceName_,
methodNameUnprefixed_};
try {
clientInterceptor->internal_onResponse(std::move(responseInfo));
} catch (...) {
Expand Down
6 changes: 5 additions & 1 deletion third-party/thrift/src/thrift/lib/cpp/ContextStack.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,12 @@ class ContextStack {
handlers_;
std::shared_ptr<std::vector<std::shared_ptr<ClientInterceptorBase>>>
clientInterceptors_;
// Must be NUL-terminated.
const char* const serviceName_;
const char* const method_;
// "{service_name}.{method_name}"
const char* const methodNamePrefixed_;
// "{method_name}", without the service name prefix
const char* const methodNameUnprefixed_;
void** serviceContexts_;
// While the server-side has a Cpp2RequestContext, the client-side "fakes" it
// with an embedded version. We can't make it nullptr because this is the API
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,31 @@ class ClientInterceptorBase {

struct RequestInfo {
detail::ClientInterceptorOnRequestStorage* storage = nullptr;
/**
* The name of the service definition as specified in Thrift IDL.
*/
const char* serviceName = nullptr;
/**
* The name of the method as specified in Thrift IDL. This does NOT include
* the service name. If the method is an interaction method, then it will be
* in the format `{interaction_name}.{method_name}`.
*/
const char* methodName = nullptr;
};
virtual void internal_onRequest(RequestInfo) = 0;

struct ResponseInfo {
detail::ClientInterceptorOnRequestStorage* storage = nullptr;
/**
* The name of the service definition as specified in Thrift IDL.
*/
const char* serviceName = nullptr;
/**
* The name of the method as specified in Thrift IDL. This does NOT include
* the service name. If the method is an interaction method, then it will be
* in the format `{interaction_name}.{method_name}`.
*/
const char* methodName = nullptr;
};
virtual void internal_onResponse(ResponseInfo) = 0;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,39 @@ class ClientInterceptorThatThrowsOnResponse
}
};

class TracingClientInterceptor : public NamedClientInterceptor<folly::Unit> {
public:
using NamedClientInterceptor::NamedClientInterceptor;

using Trace = std::pair<std::string, std::string>;
const std::vector<Trace>& requests() const { return requests_; }
const std::vector<Trace>& responses() const { return responses_; }

std::optional<folly::Unit> onRequest(RequestInfo requestInfo) override {
requests_.push_back(
{std::string(requestInfo.serviceName),
std::string(requestInfo.methodName)});
return folly::unit;
}

void onResponse(folly::Unit*, ResponseInfo responseInfo) override {
responses_.push_back(
{std::string(responseInfo.serviceName),
std::string(responseInfo.methodName)});
}

private:
std::vector<Trace> requests_;
std::vector<Trace> responses_;
};

} // namespace

CO_TEST_P(ClientInterceptorTestP, Basic) {
auto interceptor =
std::make_shared<ClientInterceptorCountWithRequestState>("Interceptor1");
auto client = makeClient(makeInterceptorsList(interceptor));
auto tracer = std::make_shared<TracingClientInterceptor>("Tracer");
auto client = makeClient(makeInterceptorsList(interceptor, tracer));

co_await client->echo("foo");
EXPECT_EQ(interceptor->onRequestCount, 1);
Expand All @@ -353,6 +380,14 @@ CO_TEST_P(ClientInterceptorTestP, Basic) {
co_await client->noop();
EXPECT_EQ(interceptor->onRequestCount, 2);
EXPECT_EQ(interceptor->onResponseCount, 2);

using Trace = TracingClientInterceptor::Trace;
const std::vector<Trace> expectedTrace{
Trace{"ClientInterceptorTest", "echo"},
Trace{"ClientInterceptorTest", "noop"},
};
EXPECT_THAT(tracer->requests(), ElementsAreArray(expectedTrace));
EXPECT_THAT(tracer->responses(), ElementsAreArray(expectedTrace));
}

CO_TEST_P(ClientInterceptorTestP, OnRequestException) {
Expand All @@ -362,8 +397,9 @@ CO_TEST_P(ClientInterceptorTestP, OnRequestException) {
std::make_shared<ClientInterceptorCountWithRequestState>("Interceptor2");
auto interceptor3 =
std::make_shared<ClientInterceptorThatThrowsOnRequest>("Interceptor3");
auto tracer = std::make_shared<TracingClientInterceptor>("Tracer");
auto client = makeClient(
makeInterceptorsList(interceptor1, interceptor2, interceptor3));
makeInterceptorsList(interceptor1, interceptor2, interceptor3, tracer));

EXPECT_THROW(
{
Expand All @@ -387,6 +423,11 @@ CO_TEST_P(ClientInterceptorTestP, OnRequestException) {
EXPECT_EQ(interceptor1->onResponseCount, 0);
EXPECT_EQ(interceptor2->onResponseCount, 0);
EXPECT_EQ(interceptor3->onResponseCount, 0);

using Trace = TracingClientInterceptor::Trace;
EXPECT_THAT(
tracer->requests(), ElementsAre(Trace{"ClientInterceptorTest", "noop"}));
EXPECT_THAT(tracer->responses(), IsEmpty());
}

CO_TEST_P(ClientInterceptorTestP, IterationOrder) {
Expand Down Expand Up @@ -556,7 +597,9 @@ CO_TEST_P(ClientInterceptorTestP, BasicInteraction) {
auto interceptor2 =
std::make_shared<ClientInterceptorCountWithRequestState>(
"Interceptor2");
auto client = makeClient(makeInterceptorsList(interceptor1, interceptor2));
auto tracer = std::make_shared<TracingClientInterceptor>("Tracer");
auto client =
makeClient(makeInterceptorsList(interceptor1, interceptor2, tracer));

{
auto interaction = co_await client->createInteraction();
Expand All @@ -582,6 +625,15 @@ CO_TEST_P(ClientInterceptorTestP, BasicInteraction) {
EXPECT_EQ(interceptor->onRequestCount, 3);
EXPECT_EQ(interceptor->onResponseCount, 3);
}

using Trace = TracingClientInterceptor::Trace;
const std::vector<Trace> expectedTrace{
Trace{"ClientInterceptorTest", "createInteraction"},
Trace{"ClientInterceptorTest", "SampleInteraction.echo"},
Trace{"ClientInterceptorTest", "echo"},
};
EXPECT_THAT(tracer->requests(), ElementsAreArray(expectedTrace));
EXPECT_THAT(tracer->responses(), ElementsAreArray(expectedTrace));
}

// With initial response
Expand All @@ -592,7 +644,9 @@ CO_TEST_P(ClientInterceptorTestP, BasicInteraction) {
auto interceptor2 =
std::make_shared<ClientInterceptorCountWithRequestState>(
"Interceptor2");
auto client = makeClient(makeInterceptorsList(interceptor1, interceptor2));
auto tracer = std::make_shared<TracingClientInterceptor>("Tracer");
auto client =
makeClient(makeInterceptorsList(interceptor1, interceptor2, tracer));
{
auto [interaction, response] =
co_await client->createInteractionAndEcho("hello");
Expand All @@ -618,6 +672,15 @@ CO_TEST_P(ClientInterceptorTestP, BasicInteraction) {
EXPECT_EQ(interceptor->onRequestCount, 3);
EXPECT_EQ(interceptor->onResponseCount, 3);
}

using Trace = TracingClientInterceptor::Trace;
const std::vector<Trace> expectedTrace{
Trace{"ClientInterceptorTest", "createInteractionAndEcho"},
Trace{"ClientInterceptorTest", "SampleInteraction.echo"},
Trace{"ClientInterceptorTest", "echo"},
};
EXPECT_THAT(tracer->requests(), ElementsAreArray(expectedTrace));
EXPECT_THAT(tracer->responses(), ElementsAreArray(expectedTrace));
}
}

Expand Down

0 comments on commit ff71d94

Please sign in to comment.