Merge pull request #11385 from liamwhite/acceptcancel
internal_network: cancel pending socket operations on application process termination
This commit is contained in:
		| @@ -406,6 +406,7 @@ struct System::Impl { | |||||||
|             gpu_core->NotifyShutdown(); |             gpu_core->NotifyShutdown(); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         Network::CancelPendingSocketOperations(); | ||||||
|         kernel.SuspendApplication(true); |         kernel.SuspendApplication(true); | ||||||
|         if (services) { |         if (services) { | ||||||
|             services->KillNVNFlinger(); |             services->KillNVNFlinger(); | ||||||
| @@ -427,6 +428,7 @@ struct System::Impl { | |||||||
|         debugger.reset(); |         debugger.reset(); | ||||||
|         kernel.Shutdown(); |         kernel.Shutdown(); | ||||||
|         memory.Reset(); |         memory.Reset(); | ||||||
|  |         Network::RestartSocketOperations(); | ||||||
|  |  | ||||||
|         if (auto room_member = room_network.GetRoomMember().lock()) { |         if (auto room_member = room_network.GetRoomMember().lock()) { | ||||||
|             Network::GameInfo game_info{}; |             Network::GameInfo game_info{}; | ||||||
|   | |||||||
| @@ -48,15 +48,32 @@ enum class CallType { | |||||||
|  |  | ||||||
| using socklen_t = int; | using socklen_t = int; | ||||||
|  |  | ||||||
|  | SOCKET interrupt_socket = static_cast<SOCKET>(-1); | ||||||
|  |  | ||||||
|  | void InterruptSocketOperations() { | ||||||
|  |     closesocket(interrupt_socket); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void AcknowledgeInterrupt() { | ||||||
|  |     interrupt_socket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); | ||||||
|  | } | ||||||
|  |  | ||||||
| void Initialize() { | void Initialize() { | ||||||
|     WSADATA wsa_data; |     WSADATA wsa_data; | ||||||
|     (void)WSAStartup(MAKEWORD(2, 2), &wsa_data); |     (void)WSAStartup(MAKEWORD(2, 2), &wsa_data); | ||||||
|  |  | ||||||
|  |     AcknowledgeInterrupt(); | ||||||
| } | } | ||||||
|  |  | ||||||
| void Finalize() { | void Finalize() { | ||||||
|  |     InterruptSocketOperations(); | ||||||
|     WSACleanup(); |     WSACleanup(); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | SOCKET GetInterruptSocket() { | ||||||
|  |     return interrupt_socket; | ||||||
|  | } | ||||||
|  |  | ||||||
| sockaddr TranslateFromSockAddrIn(SockAddrIn input) { | sockaddr TranslateFromSockAddrIn(SockAddrIn input) { | ||||||
|     sockaddr_in result; |     sockaddr_in result; | ||||||
|  |  | ||||||
| @@ -157,9 +174,42 @@ constexpr int SD_RECEIVE = SHUT_RD; | |||||||
| constexpr int SD_SEND = SHUT_WR; | constexpr int SD_SEND = SHUT_WR; | ||||||
| constexpr int SD_BOTH = SHUT_RDWR; | constexpr int SD_BOTH = SHUT_RDWR; | ||||||
|  |  | ||||||
| void Initialize() {} | int interrupt_pipe_fd[2] = {-1, -1}; | ||||||
|  |  | ||||||
| void Finalize() {} | void Initialize() { | ||||||
|  |     if (pipe(interrupt_pipe_fd) != 0) { | ||||||
|  |         LOG_ERROR(Network, "Failed to create interrupt pipe!"); | ||||||
|  |     } | ||||||
|  |     int flags = fcntl(interrupt_pipe_fd[0], F_GETFL); | ||||||
|  |     ASSERT_MSG(fcntl(interrupt_pipe_fd[0], F_SETFL, flags | O_NONBLOCK) == 0, | ||||||
|  |                "Failed to set nonblocking state for interrupt pipe"); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void Finalize() { | ||||||
|  |     if (interrupt_pipe_fd[0] >= 0) { | ||||||
|  |         close(interrupt_pipe_fd[0]); | ||||||
|  |     } | ||||||
|  |     if (interrupt_pipe_fd[1] >= 0) { | ||||||
|  |         close(interrupt_pipe_fd[1]); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void InterruptSocketOperations() { | ||||||
|  |     u8 value = 0; | ||||||
|  |     ASSERT(write(interrupt_pipe_fd[1], &value, sizeof(value)) == 1); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void AcknowledgeInterrupt() { | ||||||
|  |     u8 value = 0; | ||||||
|  |     ssize_t ret = read(interrupt_pipe_fd[0], &value, sizeof(value)); | ||||||
|  |     if (ret != 1 && errno != EAGAIN && errno != EWOULDBLOCK) { | ||||||
|  |         LOG_ERROR(Network, "Failed to acknowledge interrupt on shutdown"); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | SOCKET GetInterruptSocket() { | ||||||
|  |     return interrupt_pipe_fd[0]; | ||||||
|  | } | ||||||
|  |  | ||||||
| sockaddr TranslateFromSockAddrIn(SockAddrIn input) { | sockaddr TranslateFromSockAddrIn(SockAddrIn input) { | ||||||
|     sockaddr_in result; |     sockaddr_in result; | ||||||
| @@ -490,6 +540,14 @@ NetworkInstance::~NetworkInstance() { | |||||||
|     Finalize(); |     Finalize(); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void CancelPendingSocketOperations() { | ||||||
|  |     InterruptSocketOperations(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void RestartSocketOperations() { | ||||||
|  |     AcknowledgeInterrupt(); | ||||||
|  | } | ||||||
|  |  | ||||||
| std::optional<IPv4Address> GetHostIPv4Address() { | std::optional<IPv4Address> GetHostIPv4Address() { | ||||||
|     const auto network_interface = Network::GetSelectedNetworkInterface(); |     const auto network_interface = Network::GetSelectedNetworkInterface(); | ||||||
|     if (!network_interface.has_value()) { |     if (!network_interface.has_value()) { | ||||||
| @@ -560,7 +618,14 @@ std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) { | |||||||
|         return result; |         return result; | ||||||
|     }); |     }); | ||||||
|  |  | ||||||
|     const int result = WSAPoll(host_pollfds.data(), static_cast<ULONG>(num), timeout); |     host_pollfds.push_back(WSAPOLLFD{ | ||||||
|  |         .fd = GetInterruptSocket(), | ||||||
|  |         .events = POLLIN, | ||||||
|  |         .revents = 0, | ||||||
|  |     }); | ||||||
|  |  | ||||||
|  |     const int result = | ||||||
|  |         WSAPoll(host_pollfds.data(), static_cast<ULONG>(host_pollfds.size()), timeout); | ||||||
|     if (result == 0) { |     if (result == 0) { | ||||||
|         ASSERT(std::all_of(host_pollfds.begin(), host_pollfds.end(), |         ASSERT(std::all_of(host_pollfds.begin(), host_pollfds.end(), | ||||||
|                            [](WSAPOLLFD fd) { return fd.revents == 0; })); |                            [](WSAPOLLFD fd) { return fd.revents == 0; })); | ||||||
| @@ -627,6 +692,24 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { | |||||||
| std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { | std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { | ||||||
|     sockaddr_in addr; |     sockaddr_in addr; | ||||||
|     socklen_t addrlen = sizeof(addr); |     socklen_t addrlen = sizeof(addr); | ||||||
|  |  | ||||||
|  |     std::vector<WSAPOLLFD> host_pollfds{ | ||||||
|  |         WSAPOLLFD{fd, POLLIN, 0}, | ||||||
|  |         WSAPOLLFD{GetInterruptSocket(), POLLIN, 0}, | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     while (true) { | ||||||
|  |         const int pollres = | ||||||
|  |             WSAPoll(host_pollfds.data(), static_cast<ULONG>(host_pollfds.size()), -1); | ||||||
|  |         if (host_pollfds[1].revents != 0) { | ||||||
|  |             // Interrupt signaled before a client could be accepted, break | ||||||
|  |             return {AcceptResult{}, Errno::AGAIN}; | ||||||
|  |         } | ||||||
|  |         if (pollres > 0) { | ||||||
|  |             break; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen); |     const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen); | ||||||
|  |  | ||||||
|     if (new_socket == INVALID_SOCKET) { |     if (new_socket == INVALID_SOCKET) { | ||||||
|   | |||||||
| @@ -96,6 +96,9 @@ public: | |||||||
|     ~NetworkInstance(); |     ~NetworkInstance(); | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | void CancelPendingSocketOperations(); | ||||||
|  | void RestartSocketOperations(); | ||||||
|  |  | ||||||
| #ifdef _WIN32 | #ifdef _WIN32 | ||||||
| constexpr IPv4Address TranslateIPv4(in_addr addr) { | constexpr IPv4Address TranslateIPv4(in_addr addr) { | ||||||
|     auto& bytes = addr.S_un.S_un_b; |     auto& bytes = addr.S_un.S_un_b; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user