cmif_serialization: support non-domain sessions on domain servers

This commit is contained in:
Liam 2024-01-25 21:26:44 -05:00
parent 431df5ae93
commit a774ff935c
1 changed files with 32 additions and 33 deletions

View File

@ -97,20 +97,20 @@ constexpr RequestLayout GetDomainReplyOutLayout() {
}; };
} }
template <bool Domain, typename MethodArguments> template <typename MethodArguments>
constexpr RequestLayout GetReplyInLayout() { constexpr RequestLayout GetReplyInLayout(bool is_domain) {
return Domain ? GetDomainReplyInLayout<MethodArguments>() : GetNonDomainReplyInLayout<MethodArguments>(); return is_domain ? GetDomainReplyInLayout<MethodArguments>() : GetNonDomainReplyInLayout<MethodArguments>();
} }
template <bool Domain, typename MethodArguments> template <typename MethodArguments>
constexpr RequestLayout GetReplyOutLayout() { constexpr RequestLayout GetReplyOutLayout(bool is_domain) {
return Domain ? GetDomainReplyOutLayout<MethodArguments>() : GetNonDomainReplyOutLayout<MethodArguments>(); return is_domain ? GetDomainReplyOutLayout<MethodArguments>() : GetNonDomainReplyOutLayout<MethodArguments>();
} }
using OutTemporaryBuffers = std::array<Common::ScratchBuffer<u8>, 3>; using OutTemporaryBuffers = std::array<Common::ScratchBuffer<u8>, 3>;
template <bool Domain, typename MethodArguments, typename CallArguments, size_t PrevAlign = 1, size_t DataOffset = 0, size_t HandleIndex = 0, size_t InBufferIndex = 0, size_t OutBufferIndex = 0, bool RawDataFinished = false, size_t ArgIndex = 0> template <typename MethodArguments, typename CallArguments, size_t PrevAlign = 1, size_t DataOffset = 0, size_t HandleIndex = 0, size_t InBufferIndex = 0, size_t OutBufferIndex = 0, bool RawDataFinished = false, size_t ArgIndex = 0>
void ReadInArgument(CallArguments& args, const u8* raw_data, HLERequestContext& ctx, OutTemporaryBuffers& temp) { void ReadInArgument(bool is_domain, CallArguments& args, const u8* raw_data, HLERequestContext& ctx, OutTemporaryBuffers& temp) {
if constexpr (ArgIndex >= std::tuple_size_v<CallArguments>) { if constexpr (ArgIndex >= std::tuple_size_v<CallArguments>) {
return; return;
} else { } else {
@ -134,25 +134,25 @@ void ReadInArgument(CallArguments& args, const u8* raw_data, HLERequestContext&
std::memcpy(&std::get<ArgIndex>(args), raw_data + ArgOffset, ArgSize); std::memcpy(&std::get<ArgIndex>(args), raw_data + ArgOffset, ArgSize);
} }
return ReadInArgument<Domain, MethodArguments, CallArguments, ArgAlign, ArgEnd, HandleIndex, InBufferIndex, OutBufferIndex, false, ArgIndex + 1>(args, raw_data, ctx, temp); return ReadInArgument<MethodArguments, CallArguments, ArgAlign, ArgEnd, HandleIndex, InBufferIndex, OutBufferIndex, false, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::InInterface) { } else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::InInterface) {
constexpr size_t ArgAlign = alignof(u32); constexpr size_t ArgAlign = alignof(u32);
constexpr size_t ArgSize = sizeof(u32); constexpr size_t ArgSize = sizeof(u32);
constexpr size_t ArgOffset = Common::AlignUp(DataOffset, ArgAlign); constexpr size_t ArgOffset = Common::AlignUp(DataOffset, ArgAlign);
constexpr size_t ArgEnd = ArgOffset + ArgSize; constexpr size_t ArgEnd = ArgOffset + ArgSize;
static_assert(Domain); ASSERT(is_domain);
ASSERT(ctx.GetDomainMessageHeader().input_object_count > 0); ASSERT(ctx.GetDomainMessageHeader().input_object_count > 0);
u32 value{}; u32 value{};
std::memcpy(&value, raw_data + ArgOffset, ArgSize); std::memcpy(&value, raw_data + ArgOffset, ArgSize);
std::get<ArgIndex>(args) = ctx.GetDomainHandler<ArgType::Type>(value - 1); std::get<ArgIndex>(args) = ctx.GetDomainHandler<ArgType::Type>(value - 1);
return ReadInArgument<Domain, MethodArguments, CallArguments, ArgAlign, ArgEnd, HandleIndex, InBufferIndex, OutBufferIndex, true, ArgIndex + 1>(args, raw_data, ctx, temp); return ReadInArgument<MethodArguments, CallArguments, ArgAlign, ArgEnd, HandleIndex, InBufferIndex, OutBufferIndex, true, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::InCopyHandle) { } else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::InCopyHandle) {
std::get<ArgIndex>(args) = ctx.GetObjectFromHandle<typename ArgType::Type>(ctx.GetCopyHandle(HandleIndex)).GetPointerUnsafe(); std::get<ArgIndex>(args) = ctx.GetObjectFromHandle<typename ArgType::Type>(ctx.GetCopyHandle(HandleIndex)).GetPointerUnsafe();
return ReadInArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex + 1, InBufferIndex, OutBufferIndex, RawDataFinished, ArgIndex + 1>(args, raw_data, ctx, temp); return ReadInArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex + 1, InBufferIndex, OutBufferIndex, RawDataFinished, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::InLargeData) { } else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::InLargeData) {
constexpr size_t BufferSize = sizeof(ArgType); constexpr size_t BufferSize = sizeof(ArgType);
@ -172,7 +172,7 @@ void ReadInArgument(CallArguments& args, const u8* raw_data, HLERequestContext&
std::memcpy(&std::get<ArgIndex>(args), buffer.data(), std::min(BufferSize, buffer.size())); std::memcpy(&std::get<ArgIndex>(args), buffer.data(), std::min(BufferSize, buffer.size()));
return ReadInArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex, InBufferIndex + 1, OutBufferIndex, RawDataFinished, ArgIndex + 1>(args, raw_data, ctx, temp); return ReadInArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex, InBufferIndex + 1, OutBufferIndex, RawDataFinished, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::InBuffer) { } else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::InBuffer) {
using ElementType = typename ArgType::Type; using ElementType = typename ArgType::Type;
@ -193,14 +193,14 @@ void ReadInArgument(CallArguments& args, const u8* raw_data, HLERequestContext&
std::get<ArgIndex>(args) = std::span(ptr, size); std::get<ArgIndex>(args) = std::span(ptr, size);
return ReadInArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex, InBufferIndex + 1, OutBufferIndex, RawDataFinished, ArgIndex + 1>(args, raw_data, ctx, temp); return ReadInArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex, InBufferIndex + 1, OutBufferIndex, RawDataFinished, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutLargeData) { } else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutLargeData) {
constexpr size_t BufferSize = sizeof(ArgType); constexpr size_t BufferSize = sizeof(ArgType);
// Clear the existing data. // Clear the existing data.
std::memset(&std::get<ArgIndex>(args), 0, BufferSize); std::memset(&std::get<ArgIndex>(args), 0, BufferSize);
return ReadInArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex, InBufferIndex, OutBufferIndex + 1, RawDataFinished, ArgIndex + 1>(args, raw_data, ctx, temp); return ReadInArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex, InBufferIndex, OutBufferIndex + 1, RawDataFinished, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutBuffer) { } else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutBuffer) {
using ElementType = typename ArgType::Type; using ElementType = typename ArgType::Type;
@ -217,15 +217,15 @@ void ReadInArgument(CallArguments& args, const u8* raw_data, HLERequestContext&
std::get<ArgIndex>(args) = std::span(ptr, size); std::get<ArgIndex>(args) = std::span(ptr, size);
return ReadInArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex, InBufferIndex, OutBufferIndex + 1, RawDataFinished, ArgIndex + 1>(args, raw_data, ctx, temp); return ReadInArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex, InBufferIndex, OutBufferIndex + 1, RawDataFinished, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else { } else {
return ReadInArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex, InBufferIndex, OutBufferIndex, RawDataFinished, ArgIndex + 1>(args, raw_data, ctx, temp); return ReadInArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, HandleIndex, InBufferIndex, OutBufferIndex, RawDataFinished, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} }
} }
} }
template <bool Domain, typename MethodArguments, typename CallArguments, size_t PrevAlign = 1, size_t DataOffset = 0, size_t OutBufferIndex = 0, bool RawDataFinished = false, size_t ArgIndex = 0> template <typename MethodArguments, typename CallArguments, size_t PrevAlign = 1, size_t DataOffset = 0, size_t OutBufferIndex = 0, bool RawDataFinished = false, size_t ArgIndex = 0>
void WriteOutArgument(CallArguments& args, u8* raw_data, HLERequestContext& ctx, OutTemporaryBuffers& temp) { void WriteOutArgument(bool is_domain, CallArguments& args, u8* raw_data, HLERequestContext& ctx, OutTemporaryBuffers& temp) {
if constexpr (ArgIndex >= std::tuple_size_v<CallArguments>) { if constexpr (ArgIndex >= std::tuple_size_v<CallArguments>) {
return; return;
} else { } else {
@ -243,23 +243,23 @@ void WriteOutArgument(CallArguments& args, u8* raw_data, HLERequestContext& ctx,
std::memcpy(raw_data + ArgOffset, &std::get<ArgIndex>(args), ArgSize); std::memcpy(raw_data + ArgOffset, &std::get<ArgIndex>(args), ArgSize);
return WriteOutArgument<Domain, MethodArguments, CallArguments, ArgAlign, ArgEnd, OutBufferIndex, false, ArgIndex + 1>(args, raw_data, ctx, temp); return WriteOutArgument<MethodArguments, CallArguments, ArgAlign, ArgEnd, OutBufferIndex, false, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutInterface) { } else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutInterface) {
if constexpr (Domain) { if (is_domain) {
ctx.AddDomainObject(std::get<ArgIndex>(args)); ctx.AddDomainObject(std::get<ArgIndex>(args));
} else { } else {
ctx.AddMoveInterface(std::get<ArgIndex>(args)); ctx.AddMoveInterface(std::get<ArgIndex>(args));
} }
return WriteOutArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex, true, ArgIndex + 1>(args, raw_data, ctx, temp); return WriteOutArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex, true, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutCopyHandle) { } else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutCopyHandle) {
ctx.AddCopyObject(std::get<ArgIndex>(args)); ctx.AddCopyObject(std::get<ArgIndex>(args));
return WriteOutArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex, RawDataFinished, ArgIndex + 1>(args, raw_data, ctx, temp); return WriteOutArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex, RawDataFinished, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutMoveHandle) { } else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutMoveHandle) {
ctx.AddMoveObject(std::get<ArgIndex>(args)); ctx.AddMoveObject(std::get<ArgIndex>(args));
return WriteOutArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex, RawDataFinished, ArgIndex + 1>(args, raw_data, ctx, temp); return WriteOutArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex, RawDataFinished, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutLargeData) { } else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutLargeData) {
constexpr size_t BufferSize = sizeof(ArgType); constexpr size_t BufferSize = sizeof(ArgType);
@ -272,7 +272,7 @@ void WriteOutArgument(CallArguments& args, u8* raw_data, HLERequestContext& ctx,
ctx.WriteBufferC(&std::get<ArgIndex>(args), BufferSize, OutBufferIndex); ctx.WriteBufferC(&std::get<ArgIndex>(args), BufferSize, OutBufferIndex);
} }
return WriteOutArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex + 1, RawDataFinished, ArgIndex + 1>(args, raw_data, ctx, temp); return WriteOutArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex + 1, RawDataFinished, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutBuffer) { } else if constexpr (ArgumentTraits<ArgType>::Type == ArgumentType::OutBuffer) {
auto& buffer = temp[OutBufferIndex]; auto& buffer = temp[OutBufferIndex];
const size_t size = buffer.size(); const size_t size = buffer.size();
@ -287,9 +287,9 @@ void WriteOutArgument(CallArguments& args, u8* raw_data, HLERequestContext& ctx,
} }
} }
return WriteOutArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex + 1, RawDataFinished, ArgIndex + 1>( args, raw_data, ctx, temp); return WriteOutArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex + 1, RawDataFinished, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} else { } else {
return WriteOutArgument<Domain, MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex, RawDataFinished, ArgIndex + 1>(args, raw_data, ctx, temp); return WriteOutArgument<MethodArguments, CallArguments, PrevAlign, DataOffset, OutBufferIndex, RawDataFinished, ArgIndex + 1>(is_domain, args, raw_data, ctx, temp);
} }
} }
} }
@ -297,11 +297,10 @@ void WriteOutArgument(CallArguments& args, u8* raw_data, HLERequestContext& ctx,
template <bool Domain, typename T, typename... A> template <bool Domain, typename T, typename... A>
void CmifReplyWrapImpl(HLERequestContext& ctx, T& t, Result (T::*f)(A...)) { void CmifReplyWrapImpl(HLERequestContext& ctx, T& t, Result (T::*f)(A...)) {
// Verify domain state. // Verify domain state.
if constexpr (Domain) { if constexpr (!Domain) {
ASSERT_MSG(ctx.GetManager()->IsDomain(), "Domain reply used on non-domain session");
} else {
ASSERT_MSG(!ctx.GetManager()->IsDomain(), "Non-domain reply used on domain session"); ASSERT_MSG(!ctx.GetManager()->IsDomain(), "Non-domain reply used on domain session");
} }
const bool is_domain = Domain ? ctx.GetManager()->IsDomain() : false;
using MethodArguments = std::tuple<std::remove_reference_t<A>...>; using MethodArguments = std::tuple<std::remove_reference_t<A>...>;
@ -310,7 +309,7 @@ void CmifReplyWrapImpl(HLERequestContext& ctx, T& t, Result (T::*f)(A...)) {
// Read inputs. // Read inputs.
const size_t offset_plus_command_id = ctx.GetDataPayloadOffset() + 2; const size_t offset_plus_command_id = ctx.GetDataPayloadOffset() + 2;
ReadInArgument<Domain, MethodArguments>(call_arguments, reinterpret_cast<u8*>(ctx.CommandBuffer() + offset_plus_command_id), ctx, buffers); ReadInArgument<MethodArguments>(is_domain, call_arguments, reinterpret_cast<u8*>(ctx.CommandBuffer() + offset_plus_command_id), ctx, buffers);
// Call. // Call.
const auto Callable = [&]<typename... CallArgs>(CallArgs&... args) { const auto Callable = [&]<typename... CallArgs>(CallArgs&... args) {
@ -319,12 +318,12 @@ void CmifReplyWrapImpl(HLERequestContext& ctx, T& t, Result (T::*f)(A...)) {
const Result res = std::apply(Callable, call_arguments); const Result res = std::apply(Callable, call_arguments);
// Write result. // Write result.
constexpr RequestLayout layout = GetReplyOutLayout<Domain, MethodArguments>(); const RequestLayout layout = GetReplyOutLayout<MethodArguments>(is_domain);
IPC::ResponseBuilder rb{ctx, 2 + Common::DivCeil(layout.cmif_raw_data_size, sizeof(u32)), layout.copy_handle_count, layout.move_handle_count + layout.domain_interface_count}; IPC::ResponseBuilder rb{ctx, 2 + Common::DivCeil(layout.cmif_raw_data_size, sizeof(u32)), layout.copy_handle_count, layout.move_handle_count + layout.domain_interface_count};
rb.Push(res); rb.Push(res);
// Write out arguments. // Write out arguments.
WriteOutArgument<Domain, MethodArguments>(call_arguments, reinterpret_cast<u8*>(ctx.CommandBuffer() + rb.GetCurrentOffset()), ctx, buffers); WriteOutArgument<MethodArguments>(is_domain, call_arguments, reinterpret_cast<u8*>(ctx.CommandBuffer() + rb.GetCurrentOffset()), ctx, buffers);
} }
// clang-format on // clang-format on