//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utilities to support construction of simple RPC APIs.
//
// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
// programmers, high performance, low memory overhead, and efficient use of the
// communications channel.
//
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H
|
||
|
#define LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H
|
||
|
|
||
|
#include <map>
|
||
|
#include <thread>
|
||
|
#include <vector>
|
||
|
|
||
|
#include "llvm/ADT/STLExtras.h"
|
||
|
#include "llvm/ExecutionEngine/Orc/Shared/OrcError.h"
|
||
|
#include "llvm/ExecutionEngine/Orc/Shared/Serialization.h"
|
||
|
#include "llvm/Support/MSVCErrorWorkarounds.h"
|
||
|
|
||
|
#include <future>
|
||
|
|
||
|
namespace llvm {
|
||
|
namespace orc {
|
||
|
namespace shared {
|
||
|
|
||
|
/// Base class of all fatal RPC errors (those that necessarily result in the
|
||
|
/// termination of the RPC session).
|
||
|
class RPCFatalError : public ErrorInfo<RPCFatalError> {
|
||
|
public:
|
||
|
static char ID;
|
||
|
};
|
||
|
|
||
|
/// RPCConnectionClosed is returned from RPC operations if the RPC connection
|
||
|
/// has already been closed due to either an error or graceful disconnection.
|
||
|
class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
|
||
|
public:
|
||
|
static char ID;
|
||
|
std::error_code convertToErrorCode() const override;
|
||
|
void log(raw_ostream &OS) const override;
|
||
|
};
|
||
|
|
||
|
/// BadFunctionCall is returned from handleOne when the remote makes a call with
|
||
|
/// an unrecognized function id.
|
||
|
///
|
||
|
/// This error is fatal because Orc RPC needs to know how to parse a function
|
||
|
/// call to know where the next call starts, and if it doesn't recognize the
|
||
|
/// function id it cannot parse the call.
|
||
|
template <typename FnIdT, typename SeqNoT>
|
||
|
class BadFunctionCall
|
||
|
: public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
|
||
|
public:
|
||
|
static char ID;
|
||
|
|
||
|
BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
|
||
|
: FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
|
||
|
|
||
|
std::error_code convertToErrorCode() const override {
|
||
|
return orcError(OrcErrorCode::UnexpectedRPCCall);
|
||
|
}
|
||
|
|
||
|
void log(raw_ostream &OS) const override {
|
||
|
OS << "Call to invalid RPC function id '" << FnId
|
||
|
<< "' with "
|
||
|
"sequence number "
|
||
|
<< SeqNo;
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
FnIdT FnId;
|
||
|
SeqNoT SeqNo;
|
||
|
};
|
||
|
|
||
|
template <typename FnIdT, typename SeqNoT>
|
||
|
char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
|
||
|
|
||
|
/// InvalidSequenceNumberForResponse is returned from handleOne when a response
|
||
|
/// call arrives with a sequence number that doesn't correspond to any in-flight
|
||
|
/// function call.
|
||
|
///
|
||
|
/// This error is fatal because Orc RPC needs to know how to parse the rest of
|
||
|
/// the response call to know where the next call starts, and if it doesn't have
|
||
|
/// a result parser for this sequence number it can't do that.
|
||
|
template <typename SeqNoT>
|
||
|
class InvalidSequenceNumberForResponse
|
||
|
: public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>,
|
||
|
RPCFatalError> {
|
||
|
public:
|
||
|
static char ID;
|
||
|
|
||
|
InvalidSequenceNumberForResponse(SeqNoT SeqNo) : SeqNo(std::move(SeqNo)) {}
|
||
|
|
||
|
std::error_code convertToErrorCode() const override {
|
||
|
return orcError(OrcErrorCode::UnexpectedRPCCall);
|
||
|
};
|
||
|
|
||
|
void log(raw_ostream &OS) const override {
|
||
|
OS << "Response has unknown sequence number " << SeqNo;
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
SeqNoT SeqNo;
|
||
|
};
|
||
|
|
||
|
template <typename SeqNoT>
|
||
|
char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
|
||
|
|
||
|
/// This non-fatal error will be passed to asynchronous result handlers in place
|
||
|
/// of a result if the connection goes down before a result returns, or if the
|
||
|
/// function to be called cannot be negotiated with the remote.
|
||
|
class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
|
||
|
public:
|
||
|
static char ID;
|
||
|
|
||
|
std::error_code convertToErrorCode() const override;
|
||
|
void log(raw_ostream &OS) const override;
|
||
|
};
|
||
|
|
||
|
/// This error is returned if the remote does not have a handler installed for
|
||
|
/// the given RPC function.
|
||
|
class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
|
||
|
public:
|
||
|
static char ID;
|
||
|
|
||
|
CouldNotNegotiate(std::string Signature);
|
||
|
std::error_code convertToErrorCode() const override;
|
||
|
void log(raw_ostream &OS) const override;
|
||
|
const std::string &getSignature() const { return Signature; }
|
||
|
|
||
|
private:
|
||
|
std::string Signature;
|
||
|
};
|
||
|
|
||
|
template <typename DerivedFunc, typename FnT> class RPCFunction;
|
||
|
|
||
|
// RPC Function class.
|
||
|
// DerivedFunc should be a user defined class with a static 'getName()' method
|
||
|
// returning a const char* representing the function's name.
|
||
|
template <typename DerivedFunc, typename RetT, typename... ArgTs>
|
||
|
class RPCFunction<DerivedFunc, RetT(ArgTs...)> {
|
||
|
public:
|
||
|
/// User defined function type.
|
||
|
using Type = RetT(ArgTs...);
|
||
|
|
||
|
/// Return type.
|
||
|
using ReturnType = RetT;
|
||
|
|
||
|
/// Returns the full function prototype as a string.
|
||
|
static const char *getPrototype() {
|
||
|
static std::string Name = [] {
|
||
|
std::string Name;
|
||
|
raw_string_ostream(Name)
|
||
|
<< SerializationTypeName<RetT>::getName() << " "
|
||
|
<< DerivedFunc::getName() << "("
|
||
|
<< SerializationTypeNameSequence<ArgTs...>() << ")";
|
||
|
return Name;
|
||
|
}();
|
||
|
return Name.data();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Allocates RPC function ids during autonegotiation.
/// Specializations of this class must provide four members:
///
/// static T getInvalidId():
///   Should return a reserved id that will be used to represent missing
/// functions during autonegotiation.
///
/// static T getResponseId():
///   Should return a reserved id that will be used to send function responses
/// (return values).
///
/// static T getNegotiateId():
///   Should return a reserved id for the negotiate function, which will be used
/// to negotiate ids for user defined functions.
///
/// template <typename Func> T allocate():
///   Allocate a unique id for function Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default RPCFunctionIdAllocator implementation for integral id types. Ids
/// 0, 1 and 2 are reserved; allocate() hands out 3, 4, 5, ... in order.
template <typename T>
class RPCFunctionIdAllocator<T, std::enable_if_t<std::is_integral<T>::value>> {
public:
  static T getInvalidId() { return static_cast<T>(0); }
  static T getResponseId() { return static_cast<T>(1); }
  static T getNegotiateId() { return static_cast<T>(2); }

  /// Returns the next unused id and advances the internal counter.
  template <typename Func> T allocate() {
    T Allocated = NextId;
    ++NextId;
    return Allocated;
  }

private:
  // First id that is not one of the reserved values above.
  T NextId = 3;
};
|
||
|
|
||
|
namespace detail {
|
||
|
|
||
|
/// Exposes, as 'Type', a std::tuple whose element types are the decayed
/// argument types of a function signature.
template <typename T> class RPCFunctionArgsTuple;

template <typename RetT, typename... ArgTs>
class RPCFunctionArgsTuple<RetT(ArgTs...)> {
public:
  // std::decay_t already strips references, so no separate
  // remove_reference step is required.
  using Type = std::tuple<std::decay_t<ArgTs>...>;
};
|
||
|
|
||
|
// ResultTraits provides typedefs and utilities specific to the return type
|
||
|
// of functions.
|
||
|
template <typename RetT> class ResultTraits {
|
||
|
public:
|
||
|
// The return type wrapped in llvm::Expected.
|
||
|
using ErrorReturnType = Expected<RetT>;
|
||
|
|
||
|
#ifdef _MSC_VER
|
||
|
// The ErrorReturnType wrapped in a std::promise.
|
||
|
using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>;
|
||
|
|
||
|
// The ErrorReturnType wrapped in a std::future.
|
||
|
using ReturnFutureType = std::future<MSVCPExpected<RetT>>;
|
||
|
#else
|
||
|
// The ErrorReturnType wrapped in a std::promise.
|
||
|
using ReturnPromiseType = std::promise<ErrorReturnType>;
|
||
|
|
||
|
// The ErrorReturnType wrapped in a std::future.
|
||
|
using ReturnFutureType = std::future<ErrorReturnType>;
|
||
|
#endif
|
||
|
|
||
|
// Create a 'blank' value of the ErrorReturnType, ready and safe to
|
||
|
// overwrite.
|
||
|
static ErrorReturnType createBlankErrorReturnValue() {
|
||
|
return ErrorReturnType(RetT());
|
||
|
}
|
||
|
|
||
|
// Consume an abandoned ErrorReturnType.
|
||
|
static void consumeAbandoned(ErrorReturnType RetOrErr) {
|
||
|
consumeError(RetOrErr.takeError());
|
||
|
}
|
||
|
};
|
||
|
|
||
|
// ResultTraits specialization for void functions.
|
||
|
template <> class ResultTraits<void> {
|
||
|
public:
|
||
|
// For void functions, ErrorReturnType is llvm::Error.
|
||
|
using ErrorReturnType = Error;
|
||
|
|
||
|
#ifdef _MSC_VER
|
||
|
// The ErrorReturnType wrapped in a std::promise.
|
||
|
using ReturnPromiseType = std::promise<MSVCPError>;
|
||
|
|
||
|
// The ErrorReturnType wrapped in a std::future.
|
||
|
using ReturnFutureType = std::future<MSVCPError>;
|
||
|
#else
|
||
|
// The ErrorReturnType wrapped in a std::promise.
|
||
|
using ReturnPromiseType = std::promise<ErrorReturnType>;
|
||
|
|
||
|
// The ErrorReturnType wrapped in a std::future.
|
||
|
using ReturnFutureType = std::future<ErrorReturnType>;
|
||
|
#endif
|
||
|
|
||
|
// Create a 'blank' value of the ErrorReturnType, ready and safe to
|
||
|
// overwrite.
|
||
|
static ErrorReturnType createBlankErrorReturnValue() {
|
||
|
return ErrorReturnType::success();
|
||
|
}
|
||
|
|
||
|
// Consume an abandoned ErrorReturnType.
|
||
|
static void consumeAbandoned(ErrorReturnType Err) {
|
||
|
consumeError(std::move(Err));
|
||
|
}
|
||
|
};
|
||
|
|
||
|
// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
|
||
|
// handlers for void RPC functions to return either void (in which case they
|
||
|
// implicitly succeed) or Error (in which case their error return is
|
||
|
// propagated). See usage in HandlerTraits::runHandlerHelper.
|
||
|
template <> class ResultTraits<Error> : public ResultTraits<void> {};
|
||
|
|
||
|
// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
|
||
|
// handlers for RPC functions returning a T to return either a T (in which
|
||
|
// case they implicitly succeed) or Expected<T> (in which case their error
|
||
|
// return is propagated). See usage in HandlerTraits::runHandlerHelper.
|
||
|
template <typename RetT>
|
||
|
class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
|
||
|
|
||
|
// Determines whether an RPC function's defined error return type supports
|
||
|
// error return value.
|
||
|
template <typename T> class SupportsErrorReturn {
|
||
|
public:
|
||
|
static const bool value = false;
|
||
|
};
|
||
|
|
||
|
template <> class SupportsErrorReturn<Error> {
|
||
|
public:
|
||
|
static const bool value = true;
|
||
|
};
|
||
|
|
||
|
template <typename T> class SupportsErrorReturn<Expected<T>> {
|
||
|
public:
|
||
|
static const bool value = true;
|
||
|
};
|
||
|
|
||
|
// RespondHelper packages return values based on whether or not the declared
|
||
|
// RPC function return type supports error returns.
|
||
|
template <bool FuncSupportsErrorReturn> class RespondHelper;
|
||
|
|
||
|
// RespondHelper specialization for functions that support error returns.
|
||
|
template <> class RespondHelper<true> {
|
||
|
public:
|
||
|
// Send Expected<T>.
|
||
|
template <typename WireRetT, typename HandlerRetT, typename ChannelT,
|
||
|
typename FunctionIdT, typename SequenceNumberT>
|
||
|
static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
|
||
|
SequenceNumberT SeqNo,
|
||
|
Expected<HandlerRetT> ResultOrErr) {
|
||
|
if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
|
||
|
return ResultOrErr.takeError();
|
||
|
|
||
|
// Open the response message.
|
||
|
if (auto Err = C.startSendMessage(ResponseId, SeqNo))
|
||
|
return Err;
|
||
|
|
||
|
// Serialize the result.
|
||
|
if (auto Err =
|
||
|
SerializationTraits<ChannelT, WireRetT, Expected<HandlerRetT>>::
|
||
|
serialize(C, std::move(ResultOrErr)))
|
||
|
return Err;
|
||
|
|
||
|
// Close the response message.
|
||
|
if (auto Err = C.endSendMessage())
|
||
|
return Err;
|
||
|
return C.send();
|
||
|
}
|
||
|
|
||
|
template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
|
||
|
static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
|
||
|
SequenceNumberT SeqNo, Error Err) {
|
||
|
if (Err && Err.isA<RPCFatalError>())
|
||
|
return Err;
|
||
|
if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
|
||
|
return Err2;
|
||
|
if (auto Err2 = serializeSeq(C, std::move(Err)))
|
||
|
return Err2;
|
||
|
if (auto Err2 = C.endSendMessage())
|
||
|
return Err2;
|
||
|
return C.send();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
// RespondHelper specialization for functions that do not support error returns.
|
||
|
template <> class RespondHelper<false> {
|
||
|
public:
|
||
|
template <typename WireRetT, typename HandlerRetT, typename ChannelT,
|
||
|
typename FunctionIdT, typename SequenceNumberT>
|
||
|
static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
|
||
|
SequenceNumberT SeqNo,
|
||
|
Expected<HandlerRetT> ResultOrErr) {
|
||
|
if (auto Err = ResultOrErr.takeError())
|
||
|
return Err;
|
||
|
|
||
|
// Open the response message.
|
||
|
if (auto Err = C.startSendMessage(ResponseId, SeqNo))
|
||
|
return Err;
|
||
|
|
||
|
// Serialize the result.
|
||
|
if (auto Err =
|
||
|
SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
|
||
|
C, *ResultOrErr))
|
||
|
return Err;
|
||
|
|
||
|
// End the response message.
|
||
|
if (auto Err = C.endSendMessage())
|
||
|
return Err;
|
||
|
|
||
|
return C.send();
|
||
|
}
|
||
|
|
||
|
template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
|
||
|
static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
|
||
|
SequenceNumberT SeqNo, Error Err) {
|
||
|
if (Err)
|
||
|
return Err;
|
||
|
if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
|
||
|
return Err2;
|
||
|
if (auto Err2 = C.endSendMessage())
|
||
|
return Err2;
|
||
|
return C.send();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
// Send a response of the given wire return type (WireRetT) over the
|
||
|
// channel, with the given sequence number.
|
||
|
template <typename WireRetT, typename HandlerRetT, typename ChannelT,
|
||
|
typename FunctionIdT, typename SequenceNumberT>
|
||
|
Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
|
||
|
Expected<HandlerRetT> ResultOrErr) {
|
||
|
return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
|
||
|
template sendResult<WireRetT>(C, ResponseId, SeqNo,
|
||
|
std::move(ResultOrErr));
|
||
|
}
|
||
|
|
||
|
// Send an empty response message on the given channel to indicate that
|
||
|
// the handler ran.
|
||
|
template <typename WireRetT, typename ChannelT, typename FunctionIdT,
|
||
|
typename SequenceNumberT>
|
||
|
Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
|
||
|
Error Err) {
|
||
|
return RespondHelper<SupportsErrorReturn<WireRetT>::value>::sendResult(
|
||
|
C, ResponseId, SeqNo, std::move(Err));
|
||
|
}
|
||
|
|
||
|
// Converts a given type to the equivalent error return type.
|
||
|
template <typename T> class WrappedHandlerReturn {
|
||
|
public:
|
||
|
using Type = Expected<T>;
|
||
|
};
|
||
|
|
||
|
template <typename T> class WrappedHandlerReturn<Expected<T>> {
|
||
|
public:
|
||
|
using Type = Expected<T>;
|
||
|
};
|
||
|
|
||
|
template <> class WrappedHandlerReturn<void> {
|
||
|
public:
|
||
|
using Type = Error;
|
||
|
};
|
||
|
|
||
|
template <> class WrappedHandlerReturn<Error> {
|
||
|
public:
|
||
|
using Type = Error;
|
||
|
};
|
||
|
|
||
|
template <> class WrappedHandlerReturn<ErrorSuccess> {
|
||
|
public:
|
||
|
using Type = Error;
|
||
|
};
|
||
|
|
||
|
// Traits class that strips the response function from the list of handler
|
||
|
// arguments.
|
||
|
template <typename FnT> class AsyncHandlerTraits;
|
||
|
|
||
|
template <typename ResultT, typename... ArgTs>
|
||
|
class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>,
|
||
|
ArgTs...)> {
|
||
|
public:
|
||
|
using Type = Error(ArgTs...);
|
||
|
using ResultType = Expected<ResultT>;
|
||
|
};
|
||
|
|
||
|
template <typename... ArgTs>
|
||
|
class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
|
||
|
public:
|
||
|
using Type = Error(ArgTs...);
|
||
|
using ResultType = Error;
|
||
|
};
|
||
|
|
||
|
template <typename... ArgTs>
|
||
|
class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
|
||
|
public:
|
||
|
using Type = Error(ArgTs...);
|
||
|
using ResultType = Error;
|
||
|
};
|
||
|
|
||
|
template <typename... ArgTs>
|
||
|
class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
|
||
|
public:
|
||
|
using Type = Error(ArgTs...);
|
||
|
using ResultType = Error;
|
||
|
};
|
||
|
|
||
|
template <typename ResponseHandlerT, typename... ArgTs>
|
||
|
class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)>
|
||
|
: public AsyncHandlerTraits<Error(std::decay_t<ResponseHandlerT>,
|
||
|
ArgTs...)> {};
|
||
|
|
||
|
// This template class provides utilities related to RPC function handlers.
// The base case applies to non-function types (the template class is
// specialized for function types) and inherits from the appropriate
// specialization for the given non-function type's call operator.
template <typename HandlerT>
class HandlerTraits
    : public HandlerTraits<
          decltype(&std::remove_reference_t<HandlerT>::operator())> {};
|
||
|
|
||
|
// Traits for handlers with a given function type.
|
||
|
template <typename RetT, typename... ArgTs>
|
||
|
class HandlerTraits<RetT(ArgTs...)> {
|
||
|
public:
|
||
|
// Function type of the handler.
|
||
|
using Type = RetT(ArgTs...);
|
||
|
|
||
|
// Return type of the handler.
|
||
|
using ReturnType = RetT;
|
||
|
|
||
|
// Call the given handler with the given arguments.
|
||
|
template <typename HandlerT, typename... TArgTs>
|
||
|
static typename WrappedHandlerReturn<RetT>::Type
|
||
|
unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
|
||
|
return unpackAndRunHelper(Handler, Args,
|
||
|
std::index_sequence_for<TArgTs...>());
|
||
|
}
|
||
|
|
||
|
// Call the given handler with the given arguments.
|
||
|
template <typename HandlerT, typename ResponderT, typename... TArgTs>
|
||
|
static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
|
||
|
std::tuple<TArgTs...> &Args) {
|
||
|
return unpackAndRunAsyncHelper(Handler, Responder, Args,
|
||
|
std::index_sequence_for<TArgTs...>());
|
||
|
}
|
||
|
|
||
|
// Call the given handler with the given arguments.
|
||
|
template <typename HandlerT>
|
||
|
static std::enable_if_t<
|
||
|
std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, Error>
|
||
|
run(HandlerT &Handler, ArgTs &&...Args) {
|
||
|
Handler(std::move(Args)...);
|
||
|
return Error::success();
|
||
|
}
|
||
|
|
||
|
template <typename HandlerT, typename... TArgTs>
|
||
|
static std::enable_if_t<
|
||
|
!std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
|
||
|
typename HandlerTraits<HandlerT>::ReturnType>
|
||
|
run(HandlerT &Handler, TArgTs... Args) {
|
||
|
return Handler(std::move(Args)...);
|
||
|
}
|
||
|
|
||
|
// Serialize arguments to the channel.
|
||
|
template <typename ChannelT, typename... CArgTs>
|
||
|
static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
|
||
|
return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
|
||
|
}
|
||
|
|
||
|
// Deserialize arguments from the channel.
|
||
|
template <typename ChannelT, typename... CArgTs>
|
||
|
static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
|
||
|
return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>());
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
template <typename ChannelT, typename... CArgTs, size_t... Indexes>
|
||
|
static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
|
||
|
std::index_sequence<Indexes...> _) {
|
||
|
return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
|
||
|
C, std::get<Indexes>(Args)...);
|
||
|
}
|
||
|
|
||
|
template <typename HandlerT, typename ArgTuple, size_t... Indexes>
|
||
|
static typename WrappedHandlerReturn<
|
||
|
typename HandlerTraits<HandlerT>::ReturnType>::Type
|
||
|
unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
|
||
|
std::index_sequence<Indexes...>) {
|
||
|
return run(Handler, std::move(std::get<Indexes>(Args))...);
|
||
|
}
|
||
|
|
||
|
template <typename HandlerT, typename ResponderT, typename ArgTuple,
|
||
|
size_t... Indexes>
|
||
|
static typename WrappedHandlerReturn<
|
||
|
typename HandlerTraits<HandlerT>::ReturnType>::Type
|
||
|
unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
|
||
|
ArgTuple &Args, std::index_sequence<Indexes...>) {
|
||
|
return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
// Handler traits for free functions.
|
||
|
template <typename RetT, typename... ArgTs>
|
||
|
class HandlerTraits<RetT (*)(ArgTs...)> : public HandlerTraits<RetT(ArgTs...)> {
|
||
|
};
|
||
|
|
||
|
// Handler traits for class methods (especially call operators for lambdas).
|
||
|
template <typename Class, typename RetT, typename... ArgTs>
|
||
|
class HandlerTraits<RetT (Class::*)(ArgTs...)>
|
||
|
: public HandlerTraits<RetT(ArgTs...)> {};
|
||
|
|
||
|
// Handler traits for const class methods (especially call operators for
|
||
|
// lambdas).
|
||
|
template <typename Class, typename RetT, typename... ArgTs>
|
||
|
class HandlerTraits<RetT (Class::*)(ArgTs...) const>
|
||
|
: public HandlerTraits<RetT(ArgTs...)> {};
|
||
|
|
||
|
// Utility to peel the Expected wrapper off a response handler error type.
|
||
|
template <typename HandlerT> class ResponseHandlerArg;
|
||
|
|
||
|
template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
|
||
|
public:
|
||
|
using ArgType = Expected<ArgT>;
|
||
|
using UnwrappedArgType = ArgT;
|
||
|
};
|
||
|
|
||
|
template <typename ArgT>
|
||
|
class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
|
||
|
public:
|
||
|
using ArgType = Expected<ArgT>;
|
||
|
using UnwrappedArgType = ArgT;
|
||
|
};
|
||
|
|
||
|
template <> class ResponseHandlerArg<Error(Error)> {
|
||
|
public:
|
||
|
using ArgType = Error;
|
||
|
};
|
||
|
|
||
|
template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
|
||
|
public:
|
||
|
using ArgType = Error;
|
||
|
};
|
||
|
|
||
|
// ResponseHandler represents a handler for a not-yet-received function call
|
||
|
// result.
|
||
|
template <typename ChannelT> class ResponseHandler {
|
||
|
public:
|
||
|
virtual ~ResponseHandler() {}
|
||
|
|
||
|
// Reads the function result off the wire and acts on it. The meaning of
|
||
|
// "act" will depend on how this method is implemented in any given
|
||
|
// ResponseHandler subclass but could, for example, mean running a
|
||
|
// user-specified handler or setting a promise value.
|
||
|
virtual Error handleResponse(ChannelT &C) = 0;
|
||
|
|
||
|
// Abandons this outstanding result.
|
||
|
virtual void abandon() = 0;
|
||
|
|
||
|
// Create an error instance representing an abandoned response.
|
||
|
static Error createAbandonedResponseError() {
|
||
|
return make_error<ResponseAbandoned>();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
// ResponseHandler subclass for RPC functions with non-void returns.
|
||
|
template <typename ChannelT, typename FuncRetT, typename HandlerT>
|
||
|
class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
|
||
|
public:
|
||
|
ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
|
||
|
|
||
|
// Handle the result by deserializing it from the channel then passing it
|
||
|
// to the user defined handler.
|
||
|
Error handleResponse(ChannelT &C) override {
|
||
|
using UnwrappedArgType = typename ResponseHandlerArg<
|
||
|
typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
|
||
|
UnwrappedArgType Result;
|
||
|
if (auto Err =
|
||
|
SerializationTraits<ChannelT, FuncRetT,
|
||
|
UnwrappedArgType>::deserialize(C, Result))
|
||
|
return Err;
|
||
|
if (auto Err = C.endReceiveMessage())
|
||
|
return Err;
|
||
|
return Handler(std::move(Result));
|
||
|
}
|
||
|
|
||
|
// Abandon this response by calling the handler with an 'abandoned response'
|
||
|
// error.
|
||
|
void abandon() override {
|
||
|
if (auto Err = Handler(this->createAbandonedResponseError())) {
|
||
|
// Handlers should not fail when passed an abandoned response error.
|
||
|
report_fatal_error(std::move(Err));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
HandlerT Handler;
|
||
|
};
|
||
|
|
||
|
// ResponseHandler subclass for RPC functions with void returns.
|
||
|
template <typename ChannelT, typename HandlerT>
|
||
|
class ResponseHandlerImpl<ChannelT, void, HandlerT>
|
||
|
: public ResponseHandler<ChannelT> {
|
||
|
public:
|
||
|
ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
|
||
|
|
||
|
// Handle the result (no actual value, just a notification that the function
|
||
|
// has completed on the remote end) by calling the user-defined handler with
|
||
|
// Error::success().
|
||
|
Error handleResponse(ChannelT &C) override {
|
||
|
if (auto Err = C.endReceiveMessage())
|
||
|
return Err;
|
||
|
return Handler(Error::success());
|
||
|
}
|
||
|
|
||
|
// Abandon this response by calling the handler with an 'abandoned response'
|
||
|
// error.
|
||
|
void abandon() override {
|
||
|
if (auto Err = Handler(this->createAbandonedResponseError())) {
|
||
|
// Handlers should not fail when passed an abandoned response error.
|
||
|
report_fatal_error(std::move(Err));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
HandlerT Handler;
|
||
|
};
|
||
|
|
||
|
template <typename ChannelT, typename FuncRetT, typename HandlerT>
|
||
|
class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
|
||
|
: public ResponseHandler<ChannelT> {
|
||
|
public:
|
||
|
ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
|
||
|
|
||
|
// Handle the result by deserializing it from the channel then passing it
|
||
|
// to the user defined handler.
|
||
|
Error handleResponse(ChannelT &C) override {
|
||
|
using HandlerArgType = typename ResponseHandlerArg<
|
||
|
typename HandlerTraits<HandlerT>::Type>::ArgType;
|
||
|
HandlerArgType Result((typename HandlerArgType::value_type()));
|
||
|
|
||
|
if (auto Err = SerializationTraits<ChannelT, Expected<FuncRetT>,
|
||
|
HandlerArgType>::deserialize(C, Result))
|
||
|
return Err;
|
||
|
if (auto Err = C.endReceiveMessage())
|
||
|
return Err;
|
||
|
return Handler(std::move(Result));
|
||
|
}
|
||
|
|
||
|
// Abandon this response by calling the handler with an 'abandoned response'
|
||
|
// error.
|
||
|
void abandon() override {
|
||
|
if (auto Err = Handler(this->createAbandonedResponseError())) {
|
||
|
// Handlers should not fail when passed an abandoned response error.
|
||
|
report_fatal_error(std::move(Err));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
HandlerT Handler;
|
||
|
};
|
||
|
|
||
|
template <typename ChannelT, typename HandlerT>
|
||
|
class ResponseHandlerImpl<ChannelT, Error, HandlerT>
|
||
|
: public ResponseHandler<ChannelT> {
|
||
|
public:
|
||
|
ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
|
||
|
|
||
|
// Handle the result by deserializing it from the channel then passing it
|
||
|
// to the user defined handler.
|
||
|
Error handleResponse(ChannelT &C) override {
|
||
|
Error Result = Error::success();
|
||
|
if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize(
|
||
|
C, Result)) {
|
||
|
consumeError(std::move(Result));
|
||
|
return Err;
|
||
|
}
|
||
|
if (auto Err = C.endReceiveMessage()) {
|
||
|
consumeError(std::move(Result));
|
||
|
return Err;
|
||
|
}
|
||
|
return Handler(std::move(Result));
|
||
|
}
|
||
|
|
||
|
// Abandon this response by calling the handler with an 'abandoned response'
|
||
|
// error.
|
||
|
void abandon() override {
|
||
|
if (auto Err = Handler(this->createAbandonedResponseError())) {
|
||
|
// Handlers should not fail when passed an abandoned response error.
|
||
|
report_fatal_error(std::move(Err));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
HandlerT Handler;
|
||
|
};
|
||
|
|
||
|
// Create a ResponseHandler from a given user handler.
|
||
|
template <typename ChannelT, typename FuncRetT, typename HandlerT>
|
||
|
std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
|
||
|
return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
|
||
|
std::move(H));
|
||
|
}
|
||
|
|
||
|
// Helper for wrapping member functions up as functors. This is useful for
// installing methods as result handlers.
//
// The wrapper keeps a reference to the instance, so the instance must outlive
// the wrapper.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);

  // Binds Method to Instance; nothing is called until operator() is used.
  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Instance(Instance), Method(Method) {}

  // Invokes the bound method on the bound instance and returns its result.
  //
  // Arguments are forwarded with the value category the method declares:
  // by-value parameters are moved into the call, while lvalue-reference
  // parameters stay lvalues. (An unconditional std::move would turn a 'T &'
  // argument into an rvalue, which cannot bind to the method's 'T &'
  // parameter.)
  RetT operator()(ArgTs &&...Args) {
    return (Instance.*Method)(std::forward<ArgTs>(Args)...);
  }

private:
  ClassT &Instance;
  MethodT Method;
};
|
||
|
|
||
|
// Helper that provides a Functor for deserializing arguments.
|
||
|
template <typename... ArgTs> class ReadArgs {
|
||
|
public:
|
||
|
Error operator()() { return Error::success(); }
|
||
|
};
|
||
|
|
||
|
template <typename ArgT, typename... ArgTs>
|
||
|
class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
|
||
|
public:
|
||
|
ReadArgs(ArgT &Arg, ArgTs &...Args) : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
|
||
|
|
||
|
Error operator()(ArgT &ArgVal, ArgTs &...ArgVals) {
|
||
|
this->Arg = std::move(ArgVal);
|
||
|
return ReadArgs<ArgTs...>::operator()(ArgVals...);
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
ArgT &Arg;
|
||
|
};
|
||
|
|
||
|
// Hands out sequence numbers and takes them back for re-use. Every operation
// holds SeqNoLock while it touches the internal state.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Reset, making all sequence numbers available.
  void reset() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.clear();
    NextSequenceNumber = 0;
  }

  // Get the next available sequence number. A released number is preferred
  // (most recently released first); otherwise a fresh one is minted.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    if (!FreeSequenceNumbers.empty()) {
      SequenceNumberT Recycled = FreeSequenceNumbers.back();
      FreeSequenceNumbers.pop_back();
      return Recycled;
    }
    return NextSequenceNumber++;
  }

  // Release a sequence number, making it available for re-use.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};
|
||
|
|
||
|
// Element-wise predicate check over two tuples of types: `value` is true iff
// P<T, U>::value holds for every pair of types found at the same position in
// the first and second tuple. Only tuples of equal length are handled.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

// Both tuples exhausted: nothing left to check.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static const bool value = true;
};

// Check the leading pair, then recurse on the remaining types.
template <template <class, class> class P, typename Head1, typename... Tail1,
          typename Head2, typename... Tail2>
class RPCArgTypeCheckHelper<P, std::tuple<Head1, Tail1...>,
                            std::tuple<Head2, Tail2...>> {
  using Rest =
      RPCArgTypeCheckHelper<P, std::tuple<Tail1...>, std::tuple<Tail2...>>;

public:
  static const bool value = P<Head1, Head2>::value && Rest::value;
};
|
||
|
|
||
|
template <template <class, class> class P, typename T1Sig, typename T2Sig>
|
||
|
class RPCArgTypeCheck {
|
||
|
public:
|
||
|
using T1Tuple = typename RPCFunctionArgsTuple<T1Sig>::Type;
|
||
|
using T2Tuple = typename RPCFunctionArgsTuple<T2Sig>::Type;
|
||
|
|
||
|
static_assert(std::tuple_size<T1Tuple>::value >=
|
||
|
std::tuple_size<T2Tuple>::value,
|
||
|
"Too many arguments to RPC call");
|
||
|
static_assert(std::tuple_size<T1Tuple>::value <=
|
||
|
std::tuple_size<T2Tuple>::value,
|
||
|
"Too few arguments to RPC call");
|
||
|
|
||
|
static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
|
||
|
};
|
||
|
|
||
|
template <typename ChannelT, typename WireT, typename ConcreteT>
|
||
|
class CanSerialize {
|
||
|
private:
|
||
|
using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
|
||
|
|
||
|
template <typename T>
|
||
|
static std::true_type check(
|
||
|
std::enable_if_t<std::is_same<decltype(T::serialize(
|
||
|
std::declval<ChannelT &>(),
|
||
|
std::declval<const ConcreteT &>())),
|
||
|
Error>::value,
|
||
|
void *>);
|
||
|
|
||
|
template <typename> static std::false_type check(...);
|
||
|
|
||
|
public:
|
||
|
static const bool value = decltype(check<S>(0))::value;
|
||
|
};
|
||
|
|
||
|
template <typename ChannelT, typename WireT, typename ConcreteT>
|
||
|
class CanDeserialize {
|
||
|
private:
|
||
|
using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
|
||
|
|
||
|
template <typename T>
|
||
|
static std::true_type
|
||
|
check(std::enable_if_t<
|
||
|
std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
|
||
|
std::declval<ConcreteT &>())),
|
||
|
Error>::value,
|
||
|
void *>);
|
||
|
|
||
|
template <typename> static std::false_type check(...);
|
||
|
|
||
|
public:
|
||
|
static const bool value = decltype(check<S>(0))::value;
|
||
|
};
|
||
|
|
||
|
/// Contains primitive utilities for defining, calling and handling calls to
|
||
|
/// remote procedures. ChannelT is a bidirectional stream conforming to the
|
||
|
/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
|
||
|
/// identifier type that must be serializable on ChannelT, and SequenceNumberT
|
||
|
/// is an integral type that will be used to number in-flight function calls.
|
||
|
///
|
||
|
/// These utilities support the construction of very primitive RPC utilities.
|
||
|
/// Their intent is to ensure correct serialization and deserialization of
|
||
|
/// procedure arguments, and to keep the client and server's view of the API in
|
||
|
/// sync.
|
||
|
template <typename ImplT, typename ChannelT, typename FunctionIdT,
|
||
|
typename SequenceNumberT>
|
||
|
class RPCEndpointBase {
|
||
|
protected:
|
||
|
// Built-in RPC function declarations used by the endpoint itself.

// Function with signature void() named "__orc_rpc$invalid".
class OrcRPCInvalid : public RPCFunction<OrcRPCInvalid, void()> {
public:
  static const char *getName() {
    return "__orc_rpc$invalid";
  }
};

// Function with signature void() named "__orc_rpc$response". Its id is the
// ResponseId that handleOne compares incoming function ids against.
class OrcRPCResponse : public RPCFunction<OrcRPCResponse, void()> {
public:
  static const char *getName() {
    return "__orc_rpc$response";
  }
};

// Function taking a std::string and returning a FunctionIdT, named
// "__orc_rpc$negotiate". The constructor installs handleNegotiate as its
// handler.
class OrcRPCNegotiate
    : public RPCFunction<OrcRPCNegotiate, FunctionIdT(std::string)> {
public:
  static const char *getName() {
    return "__orc_rpc$negotiate";
  }
};
|
||
|
|
||
|
// Helper predicate for testing for the presence of SerializeTraits
|
||
|
// serializers.
|
||
|
template <typename WireT, typename ConcreteT>
|
||
|
class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
|
||
|
public:
|
||
|
using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
|
||
|
|
||
|
static_assert(value, "Missing serializer for argument (Can't serialize the "
|
||
|
"first template type argument of CanSerializeCheck "
|
||
|
"from the second)");
|
||
|
};
|
||
|
|
||
|
// Helper predicate for testing for the presence of SerializeTraits
|
||
|
// deserializers.
|
||
|
template <typename WireT, typename ConcreteT>
|
||
|
class CanDeserializeCheck
|
||
|
: detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
|
||
|
public:
|
||
|
using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
|
||
|
|
||
|
static_assert(value, "Missing deserializer for argument (Can't deserialize "
|
||
|
"the second template type argument of "
|
||
|
"CanDeserializeCheck from the first)");
|
||
|
};
|
||
|
|
||
|
public:
|
||
|
/// Construct an RPC instance on a channel.
|
||
|
RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
|
||
|
: C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
|
||
|
// Hold ResponseId in a special variable, since we expect Response to be
|
||
|
// called relatively frequently, and want to avoid the map lookup.
|
||
|
ResponseId = FnIdAllocator.getResponseId();
|
||
|
RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;
|
||
|
|
||
|
// Register the negotiate function id and handler.
|
||
|
auto NegotiateId = FnIdAllocator.getNegotiateId();
|
||
|
RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
|
||
|
Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
|
||
|
[this](const std::string &Name) { return handleNegotiate(Name); });
|
||
|
}
|
||
|
|
||
|
/// Negotiate a function id for Func with the other end of the channel.
|
||
|
template <typename Func> Error negotiateFunction(bool Retry = false) {
|
||
|
return getRemoteFunctionId<Func>(true, Retry).takeError();
|
||
|
}
|
||
|
|
||
|
/// Append a call Func, does not call send on the channel.
|
||
|
/// The first argument specifies a user-defined handler to be run when the
|
||
|
/// function returns. The handler should take an Expected<Func::ReturnType>,
|
||
|
/// or an Error (if Func::ReturnType is void). The handler will be called
|
||
|
/// with an error if the return value is abandoned due to a channel error.
|
||
|
template <typename Func, typename HandlerT, typename... ArgTs>
|
||
|
Error appendCallAsync(HandlerT Handler, const ArgTs &...Args) {
|
||
|
|
||
|
static_assert(
|
||
|
detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
|
||
|
void(ArgTs...)>::value,
|
||
|
"");
|
||
|
|
||
|
// Look up the function ID.
|
||
|
FunctionIdT FnId;
|
||
|
if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
|
||
|
FnId = *FnIdOrErr;
|
||
|
else {
|
||
|
// Negotiation failed. Notify the handler then return the negotiate-failed
|
||
|
// error.
|
||
|
cantFail(Handler(make_error<ResponseAbandoned>()));
|
||
|
return FnIdOrErr.takeError();
|
||
|
}
|
||
|
|
||
|
SequenceNumberT SeqNo; // initialized in locked scope below.
|
||
|
{
|
||
|
// Lock the pending responses map and sequence number manager.
|
||
|
std::lock_guard<std::mutex> Lock(ResponsesMutex);
|
||
|
|
||
|
// Allocate a sequence number.
|
||
|
SeqNo = SequenceNumberMgr.getSequenceNumber();
|
||
|
assert(!PendingResponses.count(SeqNo) &&
|
||
|
"Sequence number already allocated");
|
||
|
|
||
|
// Install the user handler.
|
||
|
PendingResponses[SeqNo] =
|
||
|
detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
|
||
|
std::move(Handler));
|
||
|
}
|
||
|
|
||
|
// Open the function call message.
|
||
|
if (auto Err = C.startSendMessage(FnId, SeqNo)) {
|
||
|
abandonPendingResponses();
|
||
|
return Err;
|
||
|
}
|
||
|
|
||
|
// Serialize the call arguments.
|
||
|
if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
|
||
|
C, Args...)) {
|
||
|
abandonPendingResponses();
|
||
|
return Err;
|
||
|
}
|
||
|
|
||
|
// Close the function call messagee.
|
||
|
if (auto Err = C.endSendMessage()) {
|
||
|
abandonPendingResponses();
|
||
|
return Err;
|
||
|
}
|
||
|
|
||
|
return Error::success();
|
||
|
}
|
||
|
|
||
|
Error sendAppendedCalls() { return C.send(); };
|
||
|
|
||
|
template <typename Func, typename HandlerT, typename... ArgTs>
|
||
|
Error callAsync(HandlerT Handler, const ArgTs &...Args) {
|
||
|
if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
|
||
|
return Err;
|
||
|
return C.send();
|
||
|
}
|
||
|
|
||
|
/// Handle one incoming call.
|
||
|
Error handleOne() {
|
||
|
FunctionIdT FnId;
|
||
|
SequenceNumberT SeqNo;
|
||
|
if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
|
||
|
abandonPendingResponses();
|
||
|
return Err;
|
||
|
}
|
||
|
if (FnId == ResponseId)
|
||
|
return handleResponse(SeqNo);
|
||
|
auto I = Handlers.find(FnId);
|
||
|
if (I != Handlers.end())
|
||
|
return I->second(C, SeqNo);
|
||
|
|
||
|
// else: No handler found. Report error to client?
|
||
|
return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
|
||
|
SeqNo);
|
||
|
}
|
||
|
|
||
|
/// Helper for handling setter procedures - this method returns a functor that
|
||
|
/// sets the variables referred to by Args... to values deserialized from the
|
||
|
/// channel.
|
||
|
/// E.g.
|
||
|
///
|
||
|
/// typedef Function<0, bool, int> Func1;
|
||
|
///
|
||
|
/// ...
|
||
|
/// bool B;
|
||
|
/// int I;
|
||
|
/// if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
|
||
|
/// /* Handle Args */ ;
|
||
|
///
|
||
|
template <typename... ArgTs>
|
||
|
static detail::ReadArgs<ArgTs...> readArgs(ArgTs &...Args) {
|
||
|
return detail::ReadArgs<ArgTs...>(Args...);
|
||
|
}
|
||
|
|
||
|
/// Abandon all outstanding result handlers.
|
||
|
///
|
||
|
/// This will call all currently registered result handlers to receive an
|
||
|
/// "abandoned" error as their argument. This is used internally by the RPC
|
||
|
/// in error situations, but can also be called directly by clients who are
|
||
|
/// disconnecting from the remote and don't or can't expect responses to their
|
||
|
/// outstanding calls. (Especially for outstanding blocking calls, calling
|
||
|
/// this function may be necessary to avoid dead threads).
|
||
|
void abandonPendingResponses() {
|
||
|
// Lock the pending responses map and sequence number manager.
|
||
|
std::lock_guard<std::mutex> Lock(ResponsesMutex);
|
||
|
|
||
|
for (auto &KV : PendingResponses)
|
||
|
KV.second->abandon();
|
||
|
PendingResponses.clear();
|
||
|
SequenceNumberMgr.reset();
|
||
|
}
|
||
|
|
||
|
/// Remove the handler for the given function.
|
||
|
/// A handler must currently be registered for this function.
|
||
|
template <typename Func> void removeHandler() {
|
||
|
auto IdItr = LocalFunctionIds.find(Func::getPrototype());
|
||
|
assert(IdItr != LocalFunctionIds.end() &&
|
||
|
"Function does not have a registered handler");
|
||
|
auto HandlerItr = Handlers.find(IdItr->second);
|
||
|
assert(HandlerItr != Handlers.end() &&
|
||
|
"Function does not have a registered handler");
|
||
|
Handlers.erase(HandlerItr);
|
||
|
}
|
||
|
|
||
|
/// Clear all handlers.
|
||
|
void clearHandlers() { Handlers.clear(); }
|
||
|
|
||
|
protected:
|
||
|
// The id the function-id allocator reports as invalid. handleNegotiate
// returns it for functions that have no local handler.
FunctionIdT getInvalidFunctionId() const {
  FunctionIdT InvalidId = FnIdAllocator.getInvalidId();
  return InvalidId;
}
|
||
|
|
||
|
/// Add the given handler to the handler map and make it available for
|
||
|
/// autonegotiation and execution.
|
||
|
template <typename Func, typename HandlerT>
|
||
|
void addHandlerImpl(HandlerT Handler) {
|
||
|
|
||
|
static_assert(detail::RPCArgTypeCheck<
|
||
|
CanDeserializeCheck, typename Func::Type,
|
||
|
typename detail::HandlerTraits<HandlerT>::Type>::value,
|
||
|
"");
|
||
|
|
||
|
FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
|
||
|
LocalFunctionIds[Func::getPrototype()] = NewFnId;
|
||
|
Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
|
||
|
}
|
||
|
|
||
|
template <typename Func, typename HandlerT>
|
||
|
void addAsyncHandlerImpl(HandlerT Handler) {
|
||
|
|
||
|
static_assert(
|
||
|
detail::RPCArgTypeCheck<
|
||
|
CanDeserializeCheck, typename Func::Type,
|
||
|
typename detail::AsyncHandlerTraits<
|
||
|
typename detail::HandlerTraits<HandlerT>::Type>::Type>::value,
|
||
|
"");
|
||
|
|
||
|
FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
|
||
|
LocalFunctionIds[Func::getPrototype()] = NewFnId;
|
||
|
Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
|
||
|
}
|
||
|
|
||
|
// Process an incoming response message carrying sequence number SeqNo.
//
// The matching pending handler is removed from the map, its sequence number
// is released, and the handler is then run outside the lock. An unknown
// sequence number, or a failure inside the handler, abandons all remaining
// pending responses and is reported as an error.
Error handleResponse(SequenceNumberT SeqNo) {
  using Handler = typename decltype(PendingResponses)::mapped_type;
  Handler PRHandler;

  {
    // Guard the pending responses map and the sequence number manager.
    std::unique_lock<std::mutex> Lock(ResponsesMutex);
    auto Pending = PendingResponses.find(SeqNo);

    if (Pending == PendingResponses.end()) {
      // abandonPendingResponses takes the same mutex, so release it first to
      // avoid locking it recursively.
      Lock.unlock();
      abandonPendingResponses();
      return make_error<InvalidSequenceNumberForResponse<SequenceNumberT>>(
          SeqNo);
    }

    PRHandler = std::move(Pending->second);
    PendingResponses.erase(Pending);
    SequenceNumberMgr.releaseSequenceNumber(SeqNo);
  }

  assert(PRHandler &&
         "If we didn't find a response handler we should have bailed out");

  if (auto Err = PRHandler->handleResponse(C)) {
    abandonPendingResponses();
    return Err;
  }

  return Error::success();
}
|
||
|
|
||
|
// Handler for the built-in negotiate function: map a function name to its
// local id, or to the invalid id when no function of that name is registered.
FunctionIdT handleNegotiate(const std::string &Name) {
  auto Found = LocalFunctionIds.find(Name);
  return Found == LocalFunctionIds.end() ? getInvalidFunctionId()
                                         : Found->second;
}
|
||
|
|
||
|
// Find the remote FunctionId for the given function.
|
||
|
template <typename Func>
|
||
|
Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
|
||
|
bool NegotiateIfInvalid) {
|
||
|
bool DoNegotiate;
|
||
|
|
||
|
// Check if we already have a function id...
|
||
|
auto I = RemoteFunctionIds.find(Func::getPrototype());
|
||
|
if (I != RemoteFunctionIds.end()) {
|
||
|
// If it's valid there's nothing left to do.
|
||
|
if (I->second != getInvalidFunctionId())
|
||
|
return I->second;
|
||
|
DoNegotiate = NegotiateIfInvalid;
|
||
|
} else
|
||
|
DoNegotiate = NegotiateIfNotInMap;
|
||
|
|
||
|
// We don't have a function id for Func yet, but we're allowed to try to
|
||
|
// negotiate one.
|
||
|
if (DoNegotiate) {
|
||
|
auto &Impl = static_cast<ImplT &>(*this);
|
||
|
if (auto RemoteIdOrErr =
|
||
|
Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
|
||
|
RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
|
||
|
if (*RemoteIdOrErr == getInvalidFunctionId())
|
||
|
return make_error<CouldNotNegotiate>(Func::getPrototype());
|
||
|
return *RemoteIdOrErr;
|
||
|
} else
|
||
|
return RemoteIdOrErr.takeError();
|
||
|
}
|
||
|
|
||
|
// No key was available in the map and we weren't allowed to try to
|
||
|
// negotiate one, so return an unknown function error.
|
||
|
return make_error<CouldNotNegotiate>(Func::getPrototype());
|
||
|
}
|
||
|
|
||
|
using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
|
||
|
|
||
|
// Wrap the given user handler in the necessary argument-deserialization code,
|
||
|
// result-serialization code, and call to the launch policy (if present).
|
||
|
template <typename Func, typename HandlerT>
|
||
|
WrappedHandlerFn wrapHandler(HandlerT Handler) {
|
||
|
return [this, Handler](ChannelT &Channel,
|
||
|
SequenceNumberT SeqNo) mutable -> Error {
|
||
|
// Start by deserializing the arguments.
|
||
|
using ArgsTuple = typename detail::RPCFunctionArgsTuple<
|
||
|
typename detail::HandlerTraits<HandlerT>::Type>::Type;
|
||
|
auto Args = std::make_shared<ArgsTuple>();
|
||
|
|
||
|
if (auto Err =
|
||
|
detail::HandlerTraits<typename Func::Type>::deserializeArgs(
|
||
|
Channel, *Args))
|
||
|
return Err;
|
||
|
|
||
|
// GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
|
||
|
// for RPCArgs. Void cast RPCArgs to work around this for now.
|
||
|
// FIXME: Remove this workaround once we can assume a working GCC version.
|
||
|
(void)Args;
|
||
|
|
||
|
// End receieve message, unlocking the channel for reading.
|
||
|
if (auto Err = Channel.endReceiveMessage())
|
||
|
return Err;
|
||
|
|
||
|
using HTraits = detail::HandlerTraits<HandlerT>;
|
||
|
using FuncReturn = typename Func::ReturnType;
|
||
|
return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
|
||
|
HTraits::unpackAndRun(Handler, *Args));
|
||
|
};
|
||
|
}
|
||
|
|
||
|
// Wrap the given user handler in the necessary argument-deserialization code,
|
||
|
// result-serialization code, and call to the launch policy (if present).
|
||
|
template <typename Func, typename HandlerT>
|
||
|
WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
|
||
|
return [this, Handler](ChannelT &Channel,
|
||
|
SequenceNumberT SeqNo) mutable -> Error {
|
||
|
// Start by deserializing the arguments.
|
||
|
using AHTraits = detail::AsyncHandlerTraits<
|
||
|
typename detail::HandlerTraits<HandlerT>::Type>;
|
||
|
using ArgsTuple =
|
||
|
typename detail::RPCFunctionArgsTuple<typename AHTraits::Type>::Type;
|
||
|
auto Args = std::make_shared<ArgsTuple>();
|
||
|
|
||
|
if (auto Err =
|
||
|
detail::HandlerTraits<typename Func::Type>::deserializeArgs(
|
||
|
Channel, *Args))
|
||
|
return Err;
|
||
|
|
||
|
// GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
|
||
|
// for RPCArgs. Void cast RPCArgs to work around this for now.
|
||
|
// FIXME: Remove this workaround once we can assume a working GCC version.
|
||
|
(void)Args;
|
||
|
|
||
|
// End receieve message, unlocking the channel for reading.
|
||
|
if (auto Err = Channel.endReceiveMessage())
|
||
|
return Err;
|
||
|
|
||
|
using HTraits = detail::HandlerTraits<HandlerT>;
|
||
|
using FuncReturn = typename Func::ReturnType;
|
||
|
auto Responder = [this,
|
||
|
SeqNo](typename AHTraits::ResultType RetVal) -> Error {
|
||
|
return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
|
||
|
std::move(RetVal));
|
||
|
};
|
||
|
|
||
|
return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
|
||
|
};
|
||
|
}
|
||
|
|
||
|
ChannelT &C;
|
||
|
|
||
|
bool LazyAutoNegotiation;
|
||
|
|
||
|
RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;
|
||
|
|
||
|
FunctionIdT ResponseId;
|
||
|
std::map<std::string, FunctionIdT> LocalFunctionIds;
|
||
|
std::map<const char *, FunctionIdT> RemoteFunctionIds;
|
||
|
|
||
|
std::map<FunctionIdT, WrappedHandlerFn> Handlers;
|
||
|
|
||
|
std::mutex ResponsesMutex;
|
||
|
detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
|
||
|
std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
|
||
|
PendingResponses;
|
||
|
};
|
||
|
|
||
|
} // end namespace detail
|
||
|
|
||
|
template <typename ChannelT, typename FunctionIdT = uint32_t,
|
||
|
typename SequenceNumberT = uint32_t>
|
||
|
class MultiThreadedRPCEndpoint
|
||
|
: public detail::RPCEndpointBase<
|
||
|
MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
|
||
|
ChannelT, FunctionIdT, SequenceNumberT> {
|
||
|
private:
|
||
|
using BaseClass = detail::RPCEndpointBase<
|
||
|
MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
|
||
|
ChannelT, FunctionIdT, SequenceNumberT>;
|
||
|
|
||
|
public:
|
||
|
MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
|
||
|
: BaseClass(C, LazyAutoNegotiation) {}
|
||
|
|
||
|
/// Add a handler for the given RPC function.
|
||
|
/// This installs the given handler functor for the given RPCFunction, and
|
||
|
/// makes the RPC function available for negotiation/calling from the remote.
|
||
|
template <typename Func, typename HandlerT>
|
||
|
void addHandler(HandlerT Handler) {
|
||
|
return this->template addHandlerImpl<Func>(std::move(Handler));
|
||
|
}
|
||
|
|
||
|
/// Add a class-method as a handler.
|
||
|
template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
|
||
|
void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
|
||
|
addHandler<Func>(
|
||
|
detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
|
||
|
}
|
||
|
|
||
|
template <typename Func, typename HandlerT>
|
||
|
void addAsyncHandler(HandlerT Handler) {
|
||
|
return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
|
||
|
}
|
||
|
|
||
|
/// Add a class-method as a handler.
|
||
|
template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
|
||
|
void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
|
||
|
addAsyncHandler<Func>(
|
||
|
detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
|
||
|
}
|
||
|
|
||
|
/// Return type for non-blocking call primitives.
|
||
|
template <typename Func>
|
||
|
using NonBlockingCallResult = typename detail::ResultTraits<
|
||
|
typename Func::ReturnType>::ReturnFutureType;
|
||
|
|
||
|
/// Call Func on Channel C. Does not block, does not call send. Returns a pair
|
||
|
/// of a future result and the sequence number assigned to the result.
|
||
|
///
|
||
|
/// This utility function is primarily used for single-threaded mode support,
|
||
|
/// where the sequence number can be used to wait for the corresponding
|
||
|
/// result. In multi-threaded mode the appendCallNB method, which does not
|
||
|
/// return the sequence numeber, should be preferred.
|
||
|
template <typename Func, typename... ArgTs>
|
||
|
Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &...Args) {
|
||
|
using RTraits = detail::ResultTraits<typename Func::ReturnType>;
|
||
|
using ErrorReturn = typename RTraits::ErrorReturnType;
|
||
|
using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
|
||
|
|
||
|
ErrorReturnPromise Promise;
|
||
|
auto FutureResult = Promise.get_future();
|
||
|
|
||
|
if (auto Err = this->template appendCallAsync<Func>(
|
||
|
[Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable {
|
||
|
Promise.set_value(std::move(RetOrErr));
|
||
|
return Error::success();
|
||
|
},
|
||
|
Args...)) {
|
||
|
RTraits::consumeAbandoned(FutureResult.get());
|
||
|
return std::move(Err);
|
||
|
}
|
||
|
return std::move(FutureResult);
|
||
|
}
|
||
|
|
||
|
/// The same as appendCallNBWithSeq, except that it calls C.send() to
|
||
|
/// flush the channel after serializing the call.
|
||
|
template <typename Func, typename... ArgTs>
|
||
|
Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &...Args) {
|
||
|
auto Result = appendCallNB<Func>(Args...);
|
||
|
if (!Result)
|
||
|
return Result;
|
||
|
if (auto Err = this->C.send()) {
|
||
|
this->abandonPendingResponses();
|
||
|
detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
|
||
|
std::move(Result->get()));
|
||
|
return std::move(Err);
|
||
|
}
|
||
|
return Result;
|
||
|
}
|
||
|
|
||
|
/// Call Func on Channel C. Blocks waiting for a result. Returns an Error
|
||
|
/// for void functions or an Expected<T> for functions returning a T.
|
||
|
///
|
||
|
/// This function is for use in threaded code where another thread is
|
||
|
/// handling responses and incoming calls.
|
||
|
template <typename Func, typename... ArgTs,
|
||
|
typename AltRetT = typename Func::ReturnType>
|
||
|
typename detail::ResultTraits<AltRetT>::ErrorReturnType
|
||
|
callB(const ArgTs &...Args) {
|
||
|
if (auto FutureResOrErr = callNB<Func>(Args...))
|
||
|
return FutureResOrErr->get();
|
||
|
else
|
||
|
return FutureResOrErr.takeError();
|
||
|
}
|
||
|
|
||
|
/// Handle incoming RPC calls.
|
||
|
Error handlerLoop() {
|
||
|
while (true)
|
||
|
if (auto Err = this->handleOne())
|
||
|
return Err;
|
||
|
return Error::success();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template <typename ChannelT, typename FunctionIdT = uint32_t,
|
||
|
typename SequenceNumberT = uint32_t>
|
||
|
class SingleThreadedRPCEndpoint
|
||
|
: public detail::RPCEndpointBase<
|
||
|
SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
|
||
|
ChannelT, FunctionIdT, SequenceNumberT> {
|
||
|
private:
|
||
|
using BaseClass = detail::RPCEndpointBase<
|
||
|
SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
|
||
|
ChannelT, FunctionIdT, SequenceNumberT>;
|
||
|
|
||
|
public:
|
||
|
SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
|
||
|
: BaseClass(C, LazyAutoNegotiation) {}
|
||
|
|
||
|
template <typename Func, typename HandlerT>
|
||
|
void addHandler(HandlerT Handler) {
|
||
|
return this->template addHandlerImpl<Func>(std::move(Handler));
|
||
|
}
|
||
|
|
||
|
template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
|
||
|
void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
|
||
|
addHandler<Func>(
|
||
|
detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
|
||
|
}
|
||
|
|
||
|
template <typename Func, typename HandlerT>
|
||
|
void addAsyncHandler(HandlerT Handler) {
|
||
|
return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
|
||
|
}
|
||
|
|
||
|
/// Add a class-method as a handler.
|
||
|
template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
|
||
|
void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
|
||
|
addAsyncHandler<Func>(
|
||
|
detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
|
||
|
}
|
||
|
|
||
|
template <typename Func, typename... ArgTs,
|
||
|
typename AltRetT = typename Func::ReturnType>
|
||
|
typename detail::ResultTraits<AltRetT>::ErrorReturnType
|
||
|
callB(const ArgTs &...Args) {
|
||
|
bool ReceivedResponse = false;
|
||
|
using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
|
||
|
auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
|
||
|
|
||
|
// We have to 'Check' result (which we know is in a success state at this
|
||
|
// point) so that it can be overwritten in the async handler.
|
||
|
(void)!!Result;
|
||
|
|
||
|
if (auto Err = this->template appendCallAsync<Func>(
|
||
|
[&](ResultType R) {
|
||
|
Result = std::move(R);
|
||
|
ReceivedResponse = true;
|
||
|
return Error::success();
|
||
|
},
|
||
|
Args...)) {
|
||
|
detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
|
||
|
std::move(Result));
|
||
|
return std::move(Err);
|
||
|
}
|
||
|
|
||
|
if (auto Err = this->C.send()) {
|
||
|
detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
|
||
|
std::move(Result));
|
||
|
return std::move(Err);
|
||
|
}
|
||
|
|
||
|
while (!ReceivedResponse) {
|
||
|
if (auto Err = this->handleOne()) {
|
||
|
detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
|
||
|
std::move(Result));
|
||
|
return std::move(Err);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return Result;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Asynchronous dispatch for a function on an RPC endpoint.
|
||
|
template <typename RPCClass, typename Func> class RPCAsyncDispatch {
|
||
|
public:
|
||
|
RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
|
||
|
|
||
|
template <typename HandlerT, typename... ArgTs>
|
||
|
Error operator()(HandlerT Handler, const ArgTs &...Args) const {
|
||
|
return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
RPCClass &Endpoint;
|
||
|
};
|
||
|
|
||
|
/// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
|
||
|
template <typename Func, typename RPCEndpointT>
|
||
|
RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
|
||
|
return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
|
||
|
}
|
||
|
|
||
|
/// Allows a set of asynchrounous calls to be dispatched, and then
|
||
|
/// waited on as a group.
|
||
|
class ParallelCallGroup {
|
||
|
public:
|
||
|
ParallelCallGroup() = default;
|
||
|
ParallelCallGroup(const ParallelCallGroup &) = delete;
|
||
|
ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
|
||
|
|
||
|
/// Make as asynchronous call.
|
||
|
template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
|
||
|
Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
|
||
|
const ArgTs &...Args) {
|
||
|
// Increment the count of outstanding calls. This has to happen before
|
||
|
// we invoke the call, as the handler may (depending on scheduling)
|
||
|
// be run immediately on another thread, and we don't want the decrement
|
||
|
// in the wrapped handler below to run before the increment.
|
||
|
{
|
||
|
std::unique_lock<std::mutex> Lock(M);
|
||
|
++NumOutstandingCalls;
|
||
|
}
|
||
|
|
||
|
// Wrap the user handler in a lambda that will decrement the
|
||
|
// outstanding calls count, then poke the condition variable.
|
||
|
using ArgType = typename detail::ResponseHandlerArg<
|
||
|
typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
|
||
|
auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) {
|
||
|
auto Err = Handler(std::move(Arg));
|
||
|
std::unique_lock<std::mutex> Lock(M);
|
||
|
--NumOutstandingCalls;
|
||
|
CV.notify_all();
|
||
|
return Err;
|
||
|
};
|
||
|
|
||
|
return AsyncDispatch(std::move(WrappedHandler), Args...);
|
||
|
}
|
||
|
|
||
|
/// Blocks until all calls have been completed and their return value
|
||
|
/// handlers run.
|
||
|
void wait() {
|
||
|
std::unique_lock<std::mutex> Lock(M);
|
||
|
while (NumOutstandingCalls > 0)
|
||
|
CV.wait(Lock);
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
std::mutex M;
|
||
|
std::condition_variable CV;
|
||
|
uint32_t NumOutstandingCalls = 0;
|
||
|
};
|
||
|
|
||
|
/// Convenience class for grouping RPCFunctions into APIs that can be
|
||
|
/// negotiated as a block.
|
||
|
///
|
||
|
template <typename... Funcs> class APICalls {
|
||
|
public:
|
||
|
/// Test whether this API contains Function F.
|
||
|
template <typename F> class Contains {
|
||
|
public:
|
||
|
static const bool value = false;
|
||
|
};
|
||
|
|
||
|
/// Negotiate all functions in this API.
|
||
|
template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) {
|
||
|
return Error::success();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template <typename Func, typename... Funcs> class APICalls<Func, Funcs...> {
|
||
|
public:
|
||
|
template <typename F> class Contains {
|
||
|
public:
|
||
|
static const bool value = std::is_same<F, Func>::value |
|
||
|
APICalls<Funcs...>::template Contains<F>::value;
|
||
|
};
|
||
|
|
||
|
template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) {
|
||
|
if (auto Err = R.template negotiateFunction<Func>())
|
||
|
return Err;
|
||
|
return APICalls<Funcs...>::negotiate(R);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template <typename... InnerFuncs, typename... Funcs>
|
||
|
class APICalls<APICalls<InnerFuncs...>, Funcs...> {
|
||
|
public:
|
||
|
template <typename F> class Contains {
|
||
|
public:
|
||
|
static const bool value =
|
||
|
APICalls<InnerFuncs...>::template Contains<F>::value |
|
||
|
APICalls<Funcs...>::template Contains<F>::value;
|
||
|
};
|
||
|
|
||
|
template <typename RPCEndpoint> static Error negotiate(RPCEndpoint &R) {
|
||
|
if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
|
||
|
return Err;
|
||
|
return APICalls<Funcs...>::negotiate(R);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
} // end namespace shared
|
||
|
} // end namespace orc
|
||
|
} // end namespace llvm
|
||
|
|
||
|
#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RPCUTILS_H
|