//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Utilities to support construction of simple RPC APIs. // // The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ // programmers, high performance, low memory overhead, and efficient use of the // communications channel. // //===----------------------------------------------------------------------===// #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H #define LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H #include #include #include #include "llvm/ADT/STLExtras.h" #include "llvm/ExecutionEngine/Orc/Shared/OrcError.h" #include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" #include "llvm/Support/MSVCErrorWorkarounds.h" #include namespace llvm { namespace orc { namespace shared { /// Base class of all fatal RPC errors (those that necessarily result in the /// termination of the RPC session). class RPCFatalError : public ErrorInfo { public: static char ID; }; /// RPCConnectionClosed is returned from RPC operations if the RPC connection /// has already been closed due to either an error or graceful disconnection. class ConnectionClosed : public ErrorInfo { public: static char ID; std::error_code convertToErrorCode() const override; void log(raw_ostream &OS) const override; }; /// BadFunctionCall is returned from handleOne when the remote makes a call with /// an unrecognized function id. /// /// This error is fatal because Orc RPC needs to know how to parse a function /// call to know where the next call starts, and if it doesn't recognize the /// function id it cannot parse the call. template class BadFunctionCall : public ErrorInfo, RPCFatalError> { public: static char ID; BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} std::error_code convertToErrorCode() const override { return orcError(OrcErrorCode::UnexpectedRPCCall); } void log(raw_ostream &OS) const override { OS << "Call to invalid RPC function id '" << FnId << "' with " "sequence number " << SeqNo; } private: FnIdT FnId; SeqNoT SeqNo; }; template char BadFunctionCall::ID = 0; /// InvalidSequenceNumberForResponse is returned from handleOne when a response /// call arrives with a sequence number that doesn't correspond to any in-flight /// function call. /// /// This error is fatal because Orc RPC needs to know how to parse the rest of /// the response call to know where the next call starts, and if it doesn't have /// a result parser for this sequence number it can't do that. template class InvalidSequenceNumberForResponse : public ErrorInfo, RPCFatalError> { public: static char ID; InvalidSequenceNumberForResponse(SeqNoT SeqNo) : SeqNo(std::move(SeqNo)) {} std::error_code convertToErrorCode() const override { return orcError(OrcErrorCode::UnexpectedRPCCall); }; void log(raw_ostream &OS) const override { OS << "Response has unknown sequence number " << SeqNo; } private: SeqNoT SeqNo; }; template char InvalidSequenceNumberForResponse::ID = 0; /// This non-fatal error will be passed to asynchronous result handlers in place /// of a result if the connection goes down before a result returns, or if the /// function to be called cannot be negotiated with the remote. class ResponseAbandoned : public ErrorInfo { public: static char ID; std::error_code convertToErrorCode() const override; void log(raw_ostream &OS) const override; }; /// This error is returned if the remote does not have a handler installed for /// the given RPC function. class CouldNotNegotiate : public ErrorInfo { public: static char ID; CouldNotNegotiate(std::string Signature); std::error_code convertToErrorCode() const override; void log(raw_ostream &OS) const override; const std::string &getSignature() const { return Signature; } private: std::string Signature; }; template class RPCFunction; // RPC Function class. // DerivedFunc should be a user defined class with a static 'getName()' method // returning a const char* representing the function's name. template class RPCFunction { public: /// User defined function type. using Type = RetT(ArgTs...); /// Return type. using ReturnType = RetT; /// Returns the full function prototype as a string. static const char *getPrototype() { static std::string Name = [] { std::string Name; raw_string_ostream(Name) << SerializationTypeName::getName() << " " << DerivedFunc::getName() << "(" << SerializationTypeNameSequence() << ")"; return Name; }(); return Name.data(); } }; /// Allocates RPC function ids during autonegotiation. /// Specializations of this class must provide four members: /// /// static T getInvalidId(): /// Should return a reserved id that will be used to represent missing /// functions during autonegotiation. /// /// static T getResponseId(): /// Should return a reserved id that will be used to send function responses /// (return values). /// /// static T getNegotiateId(): /// Should return a reserved id for the negotiate function, which will be used /// to negotiate ids for user defined functions. /// /// template T allocate(): /// Allocate a unique id for function Func. template class RPCFunctionIdAllocator; /// This specialization of RPCFunctionIdAllocator provides a default /// implementation for integral types. template class RPCFunctionIdAllocator::value>> { public: static T getInvalidId() { return T(0); } static T getResponseId() { return T(1); } static T getNegotiateId() { return T(2); } template T allocate() { return NextId++; } private: T NextId = 3; }; namespace detail { /// Provides a typedef for a tuple containing the decayed argument types. template class RPCFunctionArgsTuple; template class RPCFunctionArgsTuple { public: using Type = std::tuple>...>; }; // ResultTraits provides typedefs and utilities specific to the return type // of functions. template class ResultTraits { public: // The return type wrapped in llvm::Expected. using ErrorReturnType = Expected; #ifdef _MSC_VER // The ErrorReturnType wrapped in a std::promise. using ReturnPromiseType = std::promise>; // The ErrorReturnType wrapped in a std::future. using ReturnFutureType = std::future>; #else // The ErrorReturnType wrapped in a std::promise. using ReturnPromiseType = std::promise; // The ErrorReturnType wrapped in a std::future. using ReturnFutureType = std::future; #endif // Create a 'blank' value of the ErrorReturnType, ready and safe to // overwrite. static ErrorReturnType createBlankErrorReturnValue() { return ErrorReturnType(RetT()); } // Consume an abandoned ErrorReturnType. static void consumeAbandoned(ErrorReturnType RetOrErr) { consumeError(RetOrErr.takeError()); } }; // ResultTraits specialization for void functions. template <> class ResultTraits { public: // For void functions, ErrorReturnType is llvm::Error. using ErrorReturnType = Error; #ifdef _MSC_VER // The ErrorReturnType wrapped in a std::promise. using ReturnPromiseType = std::promise; // The ErrorReturnType wrapped in a std::future. using ReturnFutureType = std::future; #else // The ErrorReturnType wrapped in a std::promise. using ReturnPromiseType = std::promise; // The ErrorReturnType wrapped in a std::future. using ReturnFutureType = std::future; #endif // Create a 'blank' value of the ErrorReturnType, ready and safe to // overwrite. static ErrorReturnType createBlankErrorReturnValue() { return ErrorReturnType::success(); } // Consume an abandoned ErrorReturnType. static void consumeAbandoned(ErrorReturnType Err) { consumeError(std::move(Err)); } }; // ResultTraits is equivalent to ResultTraits. This allows // handlers for void RPC functions to return either void (in which case they // implicitly succeed) or Error (in which case their error return is // propagated). See usage in HandlerTraits::runHandlerHelper. template <> class ResultTraits : public ResultTraits {}; // ResultTraits> is equivalent to ResultTraits. This allows // handlers for RPC functions returning a T to return either a T (in which // case they implicitly succeed) or Expected (in which case their error // return is propagated). See usage in HandlerTraits::runHandlerHelper. template class ResultTraits> : public ResultTraits {}; // Determines whether an RPC function's defined error return type supports // error return value. template class SupportsErrorReturn { public: static const bool value = false; }; template <> class SupportsErrorReturn { public: static const bool value = true; }; template class SupportsErrorReturn> { public: static const bool value = true; }; // RespondHelper packages return values based on whether or not the declared // RPC function return type supports error returns. template class RespondHelper; // RespondHelper specialization for functions that support error returns. template <> class RespondHelper { public: // Send Expected. template static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, Expected ResultOrErr) { if (!ResultOrErr && ResultOrErr.template errorIsA()) return ResultOrErr.takeError(); // Open the response message. if (auto Err = C.startSendMessage(ResponseId, SeqNo)) return Err; // Serialize the result. if (auto Err = SerializationTraits>:: serialize(C, std::move(ResultOrErr))) return Err; // Close the response message. if (auto Err = C.endSendMessage()) return Err; return C.send(); } template static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, Error Err) { if (Err && Err.isA()) return Err; if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) return Err2; if (auto Err2 = serializeSeq(C, std::move(Err))) return Err2; if (auto Err2 = C.endSendMessage()) return Err2; return C.send(); } }; // RespondHelper specialization for functions that do not support error returns. template <> class RespondHelper { public: template static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, Expected ResultOrErr) { if (auto Err = ResultOrErr.takeError()) return Err; // Open the response message. if (auto Err = C.startSendMessage(ResponseId, SeqNo)) return Err; // Serialize the result. if (auto Err = SerializationTraits::serialize( C, *ResultOrErr)) return Err; // End the response message. if (auto Err = C.endSendMessage()) return Err; return C.send(); } template static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, Error Err) { if (Err) return Err; if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) return Err2; if (auto Err2 = C.endSendMessage()) return Err2; return C.send(); } }; // Send a response of the given wire return type (WireRetT) over the // channel, with the given sequence number. template Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, Expected ResultOrErr) { return RespondHelper::value>:: template sendResult(C, ResponseId, SeqNo, std::move(ResultOrErr)); } // Send an empty response message on the given channel to indicate that // the handler ran. template Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, Error Err) { return RespondHelper::value>::sendResult( C, ResponseId, SeqNo, std::move(Err)); } // Converts a given type to the equivalent error return type. template class WrappedHandlerReturn { public: using Type = Expected; }; template class WrappedHandlerReturn> { public: using Type = Expected; }; template <> class WrappedHandlerReturn { public: using Type = Error; }; template <> class WrappedHandlerReturn { public: using Type = Error; }; template <> class WrappedHandlerReturn { public: using Type = Error; }; // Traits class that strips the response function from the list of handler // arguments. template class AsyncHandlerTraits; template class AsyncHandlerTraits)>, ArgTs...)> { public: using Type = Error(ArgTs...); using ResultType = Expected; }; template class AsyncHandlerTraits, ArgTs...)> { public: using Type = Error(ArgTs...); using ResultType = Error; }; template class AsyncHandlerTraits, ArgTs...)> { public: using Type = Error(ArgTs...); using ResultType = Error; }; template class AsyncHandlerTraits, ArgTs...)> { public: using Type = Error(ArgTs...); using ResultType = Error; }; template class AsyncHandlerTraits : public AsyncHandlerTraits, ArgTs...)> {}; // This template class provides utilities related to RPC function handlers. // The base case applies to non-function types (the template class is // specialized for function types) and inherits from the appropriate // speciilization for the given non-function type's call operator. template class HandlerTraits : public HandlerTraits< decltype(&std::remove_reference::type::operator())> {}; // Traits for handlers with a given function type. template class HandlerTraits { public: // Function type of the handler. using Type = RetT(ArgTs...); // Return type of the handler. using ReturnType = RetT; // Call the given handler with the given arguments. template static typename WrappedHandlerReturn::Type unpackAndRun(HandlerT &Handler, std::tuple &Args) { return unpackAndRunHelper(Handler, Args, std::index_sequence_for()); } // Call the given handler with the given arguments. template static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, std::tuple &Args) { return unpackAndRunAsyncHelper(Handler, Responder, Args, std::index_sequence_for()); } // Call the given handler with the given arguments. template static std::enable_if_t< std::is_void::ReturnType>::value, Error> run(HandlerT &Handler, ArgTs &&...Args) { Handler(std::move(Args)...); return Error::success(); } template static std::enable_if_t< !std::is_void::ReturnType>::value, typename HandlerTraits::ReturnType> run(HandlerT &Handler, TArgTs... Args) { return Handler(std::move(Args)...); } // Serialize arguments to the channel. template static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) { return SequenceSerialization::serialize(C, CArgs...); } // Deserialize arguments from the channel. template static Error deserializeArgs(ChannelT &C, std::tuple &Args) { return deserializeArgsHelper(C, Args, std::index_sequence_for()); } private: template static Error deserializeArgsHelper(ChannelT &C, std::tuple &Args, std::index_sequence _) { return SequenceSerialization::deserialize( C, std::get(Args)...); } template static typename WrappedHandlerReturn< typename HandlerTraits::ReturnType>::Type unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, std::index_sequence) { return run(Handler, std::move(std::get(Args))...); } template static typename WrappedHandlerReturn< typename HandlerTraits::ReturnType>::Type unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, ArgTuple &Args, std::index_sequence) { return run(Handler, Responder, std::move(std::get(Args))...); } }; // Handler traits for free functions. template class HandlerTraits : public HandlerTraits { }; // Handler traits for class methods (especially call operators for lambdas). template class HandlerTraits : public HandlerTraits {}; // Handler traits for const class methods (especially call operators for // lambdas). template class HandlerTraits : public HandlerTraits {}; // Utility to peel the Expected wrapper off a response handler error type. template class ResponseHandlerArg; template class ResponseHandlerArg)> { public: using ArgType = Expected; using UnwrappedArgType = ArgT; }; template class ResponseHandlerArg)> { public: using ArgType = Expected; using UnwrappedArgType = ArgT; }; template <> class ResponseHandlerArg { public: using ArgType = Error; }; template <> class ResponseHandlerArg { public: using ArgType = Error; }; // ResponseHandler represents a handler for a not-yet-received function call // result. template class ResponseHandler { public: virtual ~ResponseHandler() {} // Reads the function result off the wire and acts on it. The meaning of // "act" will depend on how this method is implemented in any given // ResponseHandler subclass but could, for example, mean running a // user-specified handler or setting a promise value. virtual Error handleResponse(ChannelT &C) = 0; // Abandons this outstanding result. virtual void abandon() = 0; // Create an error instance representing an abandoned response. static Error createAbandonedResponseError() { return make_error(); } }; // ResponseHandler subclass for RPC functions with non-void returns. template class ResponseHandlerImpl : public ResponseHandler { public: ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} // Handle the result by deserializing it from the channel then passing it // to the user defined handler. Error handleResponse(ChannelT &C) override { using UnwrappedArgType = typename ResponseHandlerArg< typename HandlerTraits::Type>::UnwrappedArgType; UnwrappedArgType Result; if (auto Err = SerializationTraits::deserialize(C, Result)) return Err; if (auto Err = C.endReceiveMessage()) return Err; return Handler(std::move(Result)); } // Abandon this response by calling the handler with an 'abandoned response' // error. void abandon() override { if (auto Err = Handler(this->createAbandonedResponseError())) { // Handlers should not fail when passed an abandoned response error. report_fatal_error(std::move(Err)); } } private: HandlerT Handler; }; // ResponseHandler subclass for RPC functions with void returns. template class ResponseHandlerImpl : public ResponseHandler { public: ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} // Handle the result (no actual value, just a notification that the function // has completed on the remote end) by calling the user-defined handler with // Error::success(). Error handleResponse(ChannelT &C) override { if (auto Err = C.endReceiveMessage()) return Err; return Handler(Error::success()); } // Abandon this response by calling the handler with an 'abandoned response' // error. void abandon() override { if (auto Err = Handler(this->createAbandonedResponseError())) { // Handlers should not fail when passed an abandoned response error. report_fatal_error(std::move(Err)); } } private: HandlerT Handler; }; template class ResponseHandlerImpl, HandlerT> : public ResponseHandler { public: ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} // Handle the result by deserializing it from the channel then passing it // to the user defined handler. Error handleResponse(ChannelT &C) override { using HandlerArgType = typename ResponseHandlerArg< typename HandlerTraits::Type>::ArgType; HandlerArgType Result((typename HandlerArgType::value_type())); if (auto Err = SerializationTraits, HandlerArgType>::deserialize(C, Result)) return Err; if (auto Err = C.endReceiveMessage()) return Err; return Handler(std::move(Result)); } // Abandon this response by calling the handler with an 'abandoned response' // error. void abandon() override { if (auto Err = Handler(this->createAbandonedResponseError())) { // Handlers should not fail when passed an abandoned response error. report_fatal_error(std::move(Err)); } } private: HandlerT Handler; }; template class ResponseHandlerImpl : public ResponseHandler { public: ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} // Handle the result by deserializing it from the channel then passing it // to the user defined handler. Error handleResponse(ChannelT &C) override { Error Result = Error::success(); if (auto Err = SerializationTraits::deserialize( C, Result)) { consumeError(std::move(Result)); return Err; } if (auto Err = C.endReceiveMessage()) { consumeError(std::move(Result)); return Err; } return Handler(std::move(Result)); } // Abandon this response by calling the handler with an 'abandoned response' // error. void abandon() override { if (auto Err = Handler(this->createAbandonedResponseError())) { // Handlers should not fail when passed an abandoned response error. report_fatal_error(std::move(Err)); } } private: HandlerT Handler; }; // Create a ResponseHandler from a given user handler. template std::unique_ptr> createResponseHandler(HandlerT H) { return std::make_unique>( std::move(H)); } // Helper for wrapping member functions up as functors. This is useful for // installing methods as result handlers. template class MemberFnWrapper { public: using MethodT = RetT (ClassT::*)(ArgTs...); MemberFnWrapper(ClassT &Instance, MethodT Method) : Instance(Instance), Method(Method) {} RetT operator()(ArgTs &&...Args) { return (Instance.*Method)(std::move(Args)...); } private: ClassT &Instance; MethodT Method; }; // Helper that provides a Functor for deserializing arguments. template class ReadArgs { public: Error operator()() { return Error::success(); } }; template class ReadArgs : public ReadArgs { public: ReadArgs(ArgT &Arg, ArgTs &...Args) : ReadArgs(Args...), Arg(Arg) {} Error operator()(ArgT &ArgVal, ArgTs &...ArgVals) { this->Arg = std::move(ArgVal); return ReadArgs::operator()(ArgVals...); } private: ArgT &Arg; }; // Manage sequence numbers. template class SequenceNumberManager { public: // Reset, making all sequence numbers available. void reset() { std::lock_guard Lock(SeqNoLock); NextSequenceNumber = 0; FreeSequenceNumbers.clear(); } // Get the next available sequence number. Will re-use numbers that have // been released. SequenceNumberT getSequenceNumber() { std::lock_guard Lock(SeqNoLock); if (FreeSequenceNumbers.empty()) return NextSequenceNumber++; auto SequenceNumber = FreeSequenceNumbers.back(); FreeSequenceNumbers.pop_back(); return SequenceNumber; } // Release a sequence number, making it available for re-use. void releaseSequenceNumber(SequenceNumberT SequenceNumber) { std::lock_guard Lock(SeqNoLock); FreeSequenceNumbers.push_back(SequenceNumber); } private: std::mutex SeqNoLock; SequenceNumberT NextSequenceNumber = 0; std::vector FreeSequenceNumbers; }; // Checks that predicate P holds for each corresponding pair of type arguments // from T1 and T2 tuple. template