From 4d684174e04119c57c77bd27aec9d9080cae5c2b Mon Sep 17 00:00:00 2001 From: PabloMK7 Date: Thu, 27 Oct 2022 15:05:49 +0200 Subject: [PATCH] Fix socket poll and handling in windows (#6166) * Fix socket poll and handling in windows * Fix clang * Add guest timing adjust * Use platform independent time fetch * Use proper type in time_point * Fix ambiguous function call * Do suggestions * Take cpu_clock_scale into account in tick adjust --- src/core/core_timing.cpp | 16 +++ src/core/core_timing.h | 8 ++ src/core/hle/service/soc_u.cpp | 231 +++++++++++++++++++++++++-------- src/core/hle/service/soc_u.h | 24 ++++ 4 files changed, 227 insertions(+), 52 deletions(-) diff --git a/src/core/core_timing.cpp b/src/core/core_timing.cpp index fbe1e9e31..bf9adebf5 100644 --- a/src/core/core_timing.cpp +++ b/src/core/core_timing.cpp @@ -174,6 +174,22 @@ void Timing::Timer::MoveEvents() { } } +u32 Timing::Timer::StartAdjust() { + ASSERT((adjust_value_curr_handle & 1) == 0); // Should always be even + adjust_value_last = std::chrono::steady_clock::now(); + return ++adjust_value_curr_handle; +} + +void Timing::Timer::EndAdjust(u32 start_adjust_handle) { + std::chrono::time_point new_timer = std::chrono::steady_clock::now(); + ASSERT(new_timer >= adjust_value_last && start_adjust_handle == adjust_value_curr_handle); + AddTicks(nsToCycles(static_cast( + std::chrono::duration_cast(new_timer - adjust_value_last) + .count() / + cpu_clock_scale))); + ++adjust_value_curr_handle; +} + s64 Timing::Timer::GetMaxSliceLength() const { const auto& next_event = event_queue.begin(); if (next_event != event_queue.end()) { diff --git a/src/core/core_timing.h b/src/core/core_timing.h index 611122211..cb54c316d 100644 --- a/src/core/core_timing.h +++ b/src/core/core_timing.h @@ -203,6 +203,11 @@ public: void MoveEvents(); + // Use these two functions to adjust the guest system tick on host blocking operations, so + // that the guest can tell how much time passed during the host call. + u32 StartAdjust(); + void EndAdjust(u32 start_adjust_handle); + private: friend class Timing; // The queue is a min-heap using std::make_heap/push_heap/pop_heap. @@ -227,6 +232,9 @@ public: s64 downcount = MAX_SLICE_LENGTH; s64 executed_ticks = 0; u64 idled_cycles = 0; + + std::chrono::time_point adjust_value_last; + u32 adjust_value_curr_handle = 0; // Stores a scaling for the internal clockspeed. Changing this number results in // under/overclocking the guest cpu double cpu_clock_scale = 1.0; diff --git a/src/core/hle/service/soc_u.cpp b/src/core/hle/service/soc_u.cpp index 0353a9668..ef6f7f727 100644 --- a/src/core/hle/service/soc_u.cpp +++ b/src/core/hle/service/soc_u.cpp @@ -221,19 +221,25 @@ struct CTRPollFD { /// Translates the resulting events of a Poll operation from 3ds specific to platform /// specific - static u32 TranslateToPlatform(Events input_event) { + static u32 TranslateToPlatform(Events input_event, bool isOutput) { +#if _WIN32 + constexpr bool isWin = true; +#else + constexpr bool isWin = false; +#endif + u32 ret = 0; if (input_event.pollin) ret |= POLLIN; - if (input_event.pollpri) + if (input_event.pollpri && !isWin) ret |= POLLPRI; - if (input_event.pollhup) + if (input_event.pollhup && (!isWin || isOutput)) ret |= POLLHUP; - if (input_event.pollerr) + if (input_event.pollerr && (!isWin || isOutput)) ret |= POLLERR; if (input_event.pollout) ret |= POLLOUT; - if (input_event.pollnval) + if (input_event.pollnval && (isWin && isOutput)) ret |= POLLNVAL; return ret; } @@ -242,20 +248,26 @@ struct CTRPollFD { Events revents; ///< Events received (output) /// Converts a platform-specific pollfd to a 3ds specific structure - static CTRPollFD FromPlatform(pollfd const& fd) { + static CTRPollFD FromPlatform(SOC::SOC_U& socu, pollfd const& fd) { CTRPollFD result; result.events.hex = Events::TranslateTo3DS(fd.events).hex; result.revents.hex = Events::TranslateTo3DS(fd.revents).hex; - result.fd = static_cast(fd.fd); + for (auto iter = socu.open_sockets.begin(); iter != socu.open_sockets.end(); ++iter) { + if (iter->second.socket_fd == fd.fd) { + result.fd = iter->first; + break; + } + } return result; } /// Converts a 3ds specific pollfd to a platform-specific structure - static pollfd ToPlatform(CTRPollFD const& fd) { + static pollfd ToPlatform(SOC::SOC_U& socu, CTRPollFD const& fd) { pollfd result; - result.events = Events::TranslateToPlatform(fd.events); - result.revents = Events::TranslateToPlatform(fd.revents); - result.fd = fd.fd; + result.events = Events::TranslateToPlatform(fd.events, false); + result.revents = Events::TranslateToPlatform(fd.revents, true); + auto iter = socu.open_sockets.find(fd.fd); + result.fd = (iter != socu.open_sockets.end()) ? iter->second.socket_fd : 0; return result; } }; @@ -351,6 +363,14 @@ struct CTRAddrInfo { static_assert(sizeof(CTRAddrInfo) == 0x130, "Size of CTRAddrInfo is not correct"); +void SOC_U::PreTimerAdjust() { + timer_adjust_handle = Core::System::GetInstance().GetRunningCore().GetTimer().StartAdjust(); +} + +void SOC_U::PostTimerAdjust() { + Core::System::GetInstance().GetRunningCore().GetTimer().EndAdjust(timer_adjust_handle); +} + void SOC_U::CleanupSockets() { for (auto sock : open_sockets) closesocket(sock.second.socket_fd); @@ -385,21 +405,28 @@ void SOC_U::Socket(Kernel::HLERequestContext& ctx) { return; } - u32 ret = static_cast(::socket(domain, type, protocol)); + u64 ret = static_cast(::socket(domain, type, protocol)); + u32 socketHandle = GetNextSocketID(); - if ((s32)ret != SOCKET_ERROR_VALUE) - open_sockets[ret] = {ret, true}; + if ((s64)ret != SOCKET_ERROR_VALUE) + open_sockets[socketHandle] = {static_cast(ret), true}; - if ((s32)ret == SOCKET_ERROR_VALUE) + if ((s64)ret == SOCKET_ERROR_VALUE) ret = TranslateError(GET_ERRNO); rb.Push(RESULT_SUCCESS); - rb.Push(ret); + rb.Push(socketHandle); } void SOC_U::Bind(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x05, 2, 4); u32 socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } u32 len = rp.Pop(); rp.PopPID(); auto sock_addr_buf = rp.PopStaticBuffer(); @@ -409,7 +436,7 @@ void SOC_U::Bind(Kernel::HLERequestContext& ctx) { sockaddr sock_addr = CTRSockAddr::ToPlatform(ctr_sock_addr); - s32 ret = ::bind(socket_handle, &sock_addr, std::max(sizeof(sock_addr), len)); + s32 ret = ::bind(fd_info->second.socket_fd, &sock_addr, std::max(sizeof(sock_addr), len)); if (ret != 0) ret = TranslateError(GET_ERRNO); @@ -422,6 +449,12 @@ void SOC_U::Bind(Kernel::HLERequestContext& ctx) { void SOC_U::Fcntl(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x13, 3, 2); u32 socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } u32 ctr_cmd = rp.Pop(); u32 ctr_arg = rp.Pop(); rp.PopPID(); @@ -436,11 +469,10 @@ void SOC_U::Fcntl(Kernel::HLERequestContext& ctx) { if (ctr_cmd == 3) { // F_GETFL #ifdef _WIN32 posix_ret = 0; - auto iter = open_sockets.find(socket_handle); - if (iter != open_sockets.end() && iter->second.blocking == false) + if (fd_info->second.blocking == false) posix_ret |= 4; // O_NONBLOCK #else - int ret = ::fcntl(socket_handle, F_GETFL, 0); + int ret = ::fcntl(fd_info->second.socket_fd, F_GETFL, 0); if (ret == SOCKET_ERROR_VALUE) { posix_ret = TranslateError(GET_ERRNO); return; @@ -452,7 +484,7 @@ void SOC_U::Fcntl(Kernel::HLERequestContext& ctx) { } else if (ctr_cmd == 4) { // F_SETFL #ifdef _WIN32 unsigned long tmp = (ctr_arg & 4 /* O_NONBLOCK */) ? 1 : 0; - int ret = ioctlsocket(socket_handle, FIONBIO, &tmp); + int ret = ioctlsocket(fd_info->second.socket_fd, FIONBIO, &tmp); if (ret == SOCKET_ERROR_VALUE) { posix_ret = TranslateError(GET_ERRNO); return; @@ -461,7 +493,7 @@ void SOC_U::Fcntl(Kernel::HLERequestContext& ctx) { if (iter != open_sockets.end()) iter->second.blocking = (tmp == 0); #else - int flags = ::fcntl(socket_handle, F_GETFL, 0); + int flags = ::fcntl(fd_info->second.socket_fd, F_GETFL, 0); if (flags == SOCKET_ERROR_VALUE) { posix_ret = TranslateError(GET_ERRNO); return; @@ -471,7 +503,7 @@ void SOC_U::Fcntl(Kernel::HLERequestContext& ctx) { if (ctr_arg & 4) // O_NONBLOCK flags |= O_NONBLOCK; - int ret = ::fcntl(socket_handle, F_SETFL, flags); + int ret = ::fcntl(fd_info->second.socket_fd, F_SETFL, flags); if (ret == SOCKET_ERROR_VALUE) { posix_ret = TranslateError(GET_ERRNO); return; @@ -487,10 +519,16 @@ void SOC_U::Fcntl(Kernel::HLERequestContext& ctx) { void SOC_U::Listen(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x03, 2, 2); u32 socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } u32 backlog = rp.Pop(); rp.PopPID(); - s32 ret = ::listen(socket_handle, backlog); + s32 ret = ::listen(fd_info->second.socket_fd, backlog); if (ret != 0) ret = TranslateError(GET_ERRNO); @@ -505,11 +543,19 @@ void SOC_U::Accept(Kernel::HLERequestContext& ctx) { // performing nonblocking operations and spinlock until the data is available IPC::RequestParser rp(ctx, 0x04, 2, 2); const auto socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } [[maybe_unused]] const auto max_addr_len = static_cast(rp.Pop()); rp.PopPID(); sockaddr addr; socklen_t addr_len = sizeof(addr); - u32 ret = static_cast(::accept(socket_handle, &addr, &addr_len)); + PreTimerAdjust(); + u32 ret = static_cast(::accept(fd_info->second.socket_fd, &addr, &addr_len)); + PostTimerAdjust(); if (static_cast(ret) != SOCKET_ERROR_VALUE) { open_sockets[ret] = {ret, true}; @@ -552,12 +598,21 @@ void SOC_U::GetHostId(Kernel::HLERequestContext& ctx) { void SOC_U::Close(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x0B, 1, 2); u32 socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } rp.PopPID(); s32 ret = 0; - open_sockets.erase(socket_handle); - ret = closesocket(socket_handle); + PreTimerAdjust(); + ret = closesocket(fd_info->second.socket_fd); + PostTimerAdjust(); + + open_sockets.erase(socket_handle); if (ret != 0) ret = TranslateError(GET_ERRNO); @@ -570,6 +625,12 @@ void SOC_U::Close(Kernel::HLERequestContext& ctx) { void SOC_U::SendTo(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x0A, 4, 6); u32 socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } u32 len = rp.Pop(); u32 flags = rp.Pop(); u32 addr_len = rp.Pop(); @@ -578,16 +639,18 @@ void SOC_U::SendTo(Kernel::HLERequestContext& ctx) { auto dest_addr_buff = rp.PopStaticBuffer(); s32 ret = -1; + PreTimerAdjust(); if (addr_len > 0) { CTRSockAddr ctr_dest_addr; std::memcpy(&ctr_dest_addr, dest_addr_buff.data(), sizeof(ctr_dest_addr)); sockaddr dest_addr = CTRSockAddr::ToPlatform(ctr_dest_addr); - ret = ::sendto(socket_handle, reinterpret_cast(input_buff.data()), len, flags, - &dest_addr, sizeof(dest_addr)); + ret = ::sendto(fd_info->second.socket_fd, reinterpret_cast(input_buff.data()), + len, flags, &dest_addr, sizeof(dest_addr)); } else { - ret = ::sendto(socket_handle, reinterpret_cast(input_buff.data()), len, flags, - nullptr, 0); + ret = ::sendto(fd_info->second.socket_fd, reinterpret_cast(input_buff.data()), + len, flags, nullptr, 0); } + PostTimerAdjust(); if (ret == SOCKET_ERROR_VALUE) ret = TranslateError(GET_ERRNO); @@ -600,6 +663,12 @@ void SOC_U::SendTo(Kernel::HLERequestContext& ctx) { void SOC_U::RecvFromOther(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x7, 4, 4); u32 socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } u32 len = rp.Pop(); u32 flags = rp.Pop(); u32 addr_len = rp.Pop(); @@ -613,19 +682,20 @@ void SOC_U::RecvFromOther(Kernel::HLERequestContext& ctx) { socklen_t src_addr_len = sizeof(src_addr); s32 ret = -1; + PreTimerAdjust(); if (addr_len > 0) { - ret = ::recvfrom(socket_handle, reinterpret_cast(output_buff.data()), len, flags, - &src_addr, &src_addr_len); + ret = ::recvfrom(fd_info->second.socket_fd, reinterpret_cast(output_buff.data()), + len, flags, &src_addr, &src_addr_len); if (ret >= 0 && src_addr_len > 0) { ctr_src_addr = CTRSockAddr::FromPlatform(src_addr); std::memcpy(addr_buff.data(), &ctr_src_addr, sizeof(ctr_src_addr)); } } else { - ret = ::recvfrom(socket_handle, reinterpret_cast(output_buff.data()), len, flags, - NULL, 0); + ret = ::recvfrom(fd_info->second.socket_fd, reinterpret_cast(output_buff.data()), + len, flags, NULL, 0); addr_buff.resize(0); } - + PostTimerAdjust(); if (ret == SOCKET_ERROR_VALUE) { ret = TranslateError(GET_ERRNO); } else { @@ -645,6 +715,12 @@ void SOC_U::RecvFrom(Kernel::HLERequestContext& ctx) { // performing nonblocking operations and spinlock until the data is available IPC::RequestParser rp(ctx, 0x08, 4, 2); u32 socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } u32 len = rp.Pop(); u32 flags = rp.Pop(); u32 addr_len = rp.Pop(); @@ -657,19 +733,21 @@ void SOC_U::RecvFrom(Kernel::HLERequestContext& ctx) { socklen_t src_addr_len = sizeof(src_addr); s32 ret = -1; + PreTimerAdjust(); if (addr_len > 0) { // Only get src adr if input adr available - ret = ::recvfrom(socket_handle, reinterpret_cast(output_buff.data()), len, flags, - &src_addr, &src_addr_len); + ret = ::recvfrom(fd_info->second.socket_fd, reinterpret_cast(output_buff.data()), + len, flags, &src_addr, &src_addr_len); if (ret >= 0 && src_addr_len > 0) { ctr_src_addr = CTRSockAddr::FromPlatform(src_addr); std::memcpy(addr_buff.data(), &ctr_src_addr, sizeof(ctr_src_addr)); } } else { - ret = ::recvfrom(socket_handle, reinterpret_cast(output_buff.data()), len, flags, - NULL, 0); + ret = ::recvfrom(fd_info->second.socket_fd, reinterpret_cast(output_buff.data()), + len, flags, NULL, 0); addr_buff.resize(0); } + PostTimerAdjust(); s32 total_received = ret; if (ret == SOCKET_ERROR_VALUE) { @@ -700,21 +778,32 @@ void SOC_U::Poll(Kernel::HLERequestContext& ctx) { // The 3ds_pollfd and the pollfd structures may be different (Windows/Linux have different // sizes) - // so we have to copy the data + // so we have to copy the data in order std::vector platform_pollfd(nfds); - std::transform(ctr_fds.begin(), ctr_fds.end(), platform_pollfd.begin(), CTRPollFD::ToPlatform); + for (u32 i = 0; i < nfds; i++) { + platform_pollfd[i] = CTRPollFD::ToPlatform(*this, ctr_fds[i]); + } + PreTimerAdjust(); s32 ret = ::poll(platform_pollfd.data(), nfds, timeout); + PostTimerAdjust(); - // Now update the output pollfd structure - std::transform(platform_pollfd.begin(), platform_pollfd.end(), ctr_fds.begin(), - CTRPollFD::FromPlatform); + // Now update the output 3ds_pollfd structure + for (u32 i = 0; i < nfds; i++) { + ctr_fds[i] = CTRPollFD::FromPlatform(*this, platform_pollfd[i]); + } std::vector output_fds(nfds * sizeof(CTRPollFD)); std::memcpy(output_fds.data(), ctr_fds.data(), nfds * sizeof(CTRPollFD)); - if (ret == SOCKET_ERROR_VALUE) + if (ret == SOCKET_ERROR_VALUE) { + int err = GET_ERRNO; + LOG_ERROR(Service_SOC, "Socket error: {}", err); + ret = TranslateError(GET_ERRNO); + } + + size_t test = platform_pollfd.size(); IPC::RequestBuilder rb = rp.MakeBuilder(2, 2); rb.Push(RESULT_SUCCESS); @@ -725,12 +814,18 @@ void SOC_U::Poll(Kernel::HLERequestContext& ctx) { void SOC_U::GetSockName(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x17, 2, 2); const auto socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } [[maybe_unused]] const auto max_addr_len = rp.Pop(); rp.PopPID(); sockaddr dest_addr; socklen_t dest_addr_len = sizeof(dest_addr); - s32 ret = ::getsockname(socket_handle, &dest_addr, &dest_addr_len); + s32 ret = ::getsockname(fd_info->second.socket_fd, &dest_addr, &dest_addr_len); CTRSockAddr ctr_dest_addr = CTRSockAddr::FromPlatform(dest_addr); std::vector dest_addr_buff(sizeof(ctr_dest_addr)); @@ -748,10 +843,16 @@ void SOC_U::GetSockName(Kernel::HLERequestContext& ctx) { void SOC_U::Shutdown(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x0C, 2, 2); u32 socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } s32 how = rp.Pop(); rp.PopPID(); - s32 ret = ::shutdown(socket_handle, how); + s32 ret = ::shutdown(fd_info->second.socket_fd, how); if (ret != 0) ret = TranslateError(GET_ERRNO); IPC::RequestBuilder rb = rp.MakeBuilder(2, 0); @@ -762,12 +863,18 @@ void SOC_U::Shutdown(Kernel::HLERequestContext& ctx) { void SOC_U::GetPeerName(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x18, 2, 2); const auto socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } [[maybe_unused]] const auto max_addr_len = rp.Pop(); rp.PopPID(); sockaddr dest_addr; socklen_t dest_addr_len = sizeof(dest_addr); - const int ret = ::getpeername(socket_handle, &dest_addr, &dest_addr_len); + const int ret = ::getpeername(fd_info->second.socket_fd, &dest_addr, &dest_addr_len); CTRSockAddr ctr_dest_addr = CTRSockAddr::FromPlatform(dest_addr); std::vector dest_addr_buff(sizeof(ctr_dest_addr)); @@ -790,6 +897,12 @@ void SOC_U::Connect(Kernel::HLERequestContext& ctx) { // performing nonblocking operations and spinlock until the data is available IPC::RequestParser rp(ctx, 0x06, 2, 4); const auto socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } [[maybe_unused]] const auto input_addr_len = rp.Pop(); rp.PopPID(); auto input_addr_buf = rp.PopStaticBuffer(); @@ -798,7 +911,9 @@ void SOC_U::Connect(Kernel::HLERequestContext& ctx) { std::memcpy(&ctr_input_addr, input_addr_buf.data(), sizeof(ctr_input_addr)); sockaddr input_addr = CTRSockAddr::ToPlatform(ctr_input_addr); - s32 ret = ::connect(socket_handle, &input_addr, sizeof(input_addr)); + PreTimerAdjust(); + s32 ret = ::connect(fd_info->second.socket_fd, &input_addr, sizeof(input_addr)); + PostTimerAdjust(); if (ret != 0) ret = TranslateError(GET_ERRNO); @@ -830,6 +945,12 @@ void SOC_U::ShutdownSockets(Kernel::HLERequestContext& ctx) { void SOC_U::GetSockOpt(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x11, 4, 2); u32 socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } u32 level = rp.Pop(); s32 optname = rp.Pop(); socklen_t optlen = static_cast(rp.Pop()); @@ -847,7 +968,7 @@ void SOC_U::GetSockOpt(Kernel::HLERequestContext& ctx) { #endif } else { char* optval_data = reinterpret_cast(optval.data()); - err = ::getsockopt(socket_handle, level, optname, optval_data, &optlen); + err = ::getsockopt(fd_info->second.socket_fd, level, optname, optval_data, &optlen); if (err == SOCKET_ERROR_VALUE) { err = TranslateError(GET_ERRNO); } @@ -863,6 +984,12 @@ void SOC_U::GetSockOpt(Kernel::HLERequestContext& ctx) { void SOC_U::SetSockOpt(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx, 0x12, 4, 4); const auto socket_handle = rp.Pop(); + auto fd_info = open_sockets.find(socket_handle); + if (fd_info == open_sockets.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ERR_INVALID_HANDLE); + return; + } const auto level = rp.Pop(); const auto optname = rp.Pop(); [[maybe_unused]] const auto optlen = static_cast(rp.Pop()); @@ -879,7 +1006,7 @@ void SOC_U::SetSockOpt(Kernel::HLERequestContext& ctx) { #endif } else { const char* optval_data = reinterpret_cast(optval.data()); - err = static_cast(::setsockopt(socket_handle, level, optname, optval_data, + err = static_cast(::setsockopt(fd_info->second.socket_fd, level, optname, optval_data, static_cast(optval.size()))); if (err == SOCKET_ERROR_VALUE) { err = TranslateError(GET_ERRNO); diff --git a/src/core/hle/service/soc_u.h b/src/core/hle/service/soc_u.h index 595a984f2..c2c57e4a2 100644 --- a/src/core/hle/service/soc_u.h +++ b/src/core/hle/service/soc_u.h @@ -6,6 +6,7 @@ #include #include +#include "core/hle/result.h" #include "core/hle/service/service.h" namespace Core { @@ -16,7 +17,13 @@ namespace Service::SOC { /// Holds information about a particular socket struct SocketHolder { +#ifdef _WIN32 + using SOCKET = unsigned long long; + SOCKET socket_fd; ///< The socket descriptor +#else u32 socket_fd; ///< The socket descriptor +#endif // _WIN32 + bool blocking; ///< Whether the socket is blocking or not, it is only read on Windows. private: @@ -34,6 +41,10 @@ public: ~SOC_U(); private: + static constexpr ResultCode ERR_INVALID_HANDLE = + ResultCode(ErrorDescription::InvalidHandle, ErrorModule::SOC, ErrorSummary::InvalidArgument, + ErrorLevel::Permanent); + void Socket(Kernel::HLERequestContext& ctx); void Bind(Kernel::HLERequestContext& ctx); void Fcntl(Kernel::HLERequestContext& ctx); @@ -59,16 +70,29 @@ private: void GetAddrInfoImpl(Kernel::HLERequestContext& ctx); void GetNameInfoImpl(Kernel::HLERequestContext& ctx); + // Socked ids + u32 next_socket_id = 3; + u32 GetNextSocketID() { + return next_socket_id++; + } + + // System timer adjust + u32 timer_adjust_handle; + void PreTimerAdjust(); + void PostTimerAdjust(); + /// Close all open sockets void CleanupSockets(); /// Holds info about the currently open sockets + friend struct CTRPollFD; std::unordered_map open_sockets; template void serialize(Archive& ar, const unsigned int) { ar& boost::serialization::base_object(*this); ar& open_sockets; + ar& timer_adjust_handle; } friend class boost::serialization::access; };