@@ -63,6 +63,18 @@ option(YUZU_DOWNLOAD_TIME_ZONE_DATA "Always download time zone binaries" OFF)
 | 
			
		||||
 | 
			
		||||
CMAKE_DEPENDENT_OPTION(YUZU_USE_FASTER_LD "Check if a faster linker is available" ON "NOT WIN32" OFF)
 | 
			
		||||
 | 
			
		||||
set(DEFAULT_ENABLE_OPENSSL ON)
 | 
			
		||||
if (ANDROID OR WIN32 OR APPLE)
 | 
			
		||||
    # - Windows defaults to the Schannel backend.
 | 
			
		||||
    # - macOS defaults to the SecureTransport backend.
 | 
			
		||||
    # - Android currently has no SSL backend as the NDK doesn't include any SSL
 | 
			
		||||
    #   library; a proper 'native' backend would have to go through Java.
 | 
			
		||||
    # But you can force builds for those platforms to use OpenSSL if you have
 | 
			
		||||
    # your own copy of it.
 | 
			
		||||
    set(DEFAULT_ENABLE_OPENSSL OFF)
 | 
			
		||||
endif()
 | 
			
		||||
option(ENABLE_OPENSSL "Enable OpenSSL backend for ISslConnection" ${DEFAULT_ENABLE_OPENSSL})
 | 
			
		||||
 | 
			
		||||
# On Android, fetch and compile libcxx before doing anything else
 | 
			
		||||
if (ANDROID)
 | 
			
		||||
    set(CMAKE_SKIP_INSTALL_RULES ON)
 | 
			
		||||
@@ -322,6 +334,10 @@ if (MINGW)
 | 
			
		||||
    find_library(MSWSOCK_LIBRARY mswsock REQUIRED)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
if(ENABLE_OPENSSL)
 | 
			
		||||
    find_package(OpenSSL 1.1.1 REQUIRED)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
# Please consider this as a stub
 | 
			
		||||
if(ENABLE_QT6 AND Qt6_LOCATION)
 | 
			
		||||
    list(APPEND CMAKE_PREFIX_PATH "${Qt6_LOCATION}")
 | 
			
		||||
 
 | 
			
		||||
@@ -5,15 +5,19 @@
 | 
			
		||||
 | 
			
		||||
#include "common/common_types.h"
 | 
			
		||||
 | 
			
		||||
#include <optional>
 | 
			
		||||
 | 
			
		||||
namespace Network {
 | 
			
		||||
 | 
			
		||||
/// Address families
 | 
			
		||||
enum class Domain : u8 {
 | 
			
		||||
    INET, ///< Address family for IPv4
 | 
			
		||||
    Unspecified, ///< Represents 0, used in getaddrinfo hints
 | 
			
		||||
    INET,        ///< Address family for IPv4
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/// Socket types
 | 
			
		||||
enum class Type {
 | 
			
		||||
    Unspecified, ///< Represents 0, used in getaddrinfo hints
 | 
			
		||||
    STREAM,
 | 
			
		||||
    DGRAM,
 | 
			
		||||
    RAW,
 | 
			
		||||
@@ -22,6 +26,7 @@ enum class Type {
 | 
			
		||||
 | 
			
		||||
/// Protocol values for sockets
 | 
			
		||||
enum class Protocol : u8 {
 | 
			
		||||
    Unspecified, ///< Represents 0, usable in various places
 | 
			
		||||
    ICMP,
 | 
			
		||||
    TCP,
 | 
			
		||||
    UDP,
 | 
			
		||||
@@ -48,4 +53,13 @@ constexpr u32 FLAG_MSG_PEEK = 0x2;
 | 
			
		||||
constexpr u32 FLAG_MSG_DONTWAIT = 0x80;
 | 
			
		||||
constexpr u32 FLAG_O_NONBLOCK = 0x800;
 | 
			
		||||
 | 
			
		||||
/// Cross-platform addrinfo structure
 | 
			
		||||
struct AddrInfo {
 | 
			
		||||
    Domain family;
 | 
			
		||||
    Type socket_type;
 | 
			
		||||
    Protocol protocol;
 | 
			
		||||
    SockAddrIn addr;
 | 
			
		||||
    std::optional<std::string> canon_name;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace Network
 | 
			
		||||
 
 | 
			
		||||
@@ -723,6 +723,7 @@ add_library(core STATIC
 | 
			
		||||
    hle/service/spl/spl_types.h
 | 
			
		||||
    hle/service/ssl/ssl.cpp
 | 
			
		||||
    hle/service/ssl/ssl.h
 | 
			
		||||
    hle/service/ssl/ssl_backend.h
 | 
			
		||||
    hle/service/time/clock_types.h
 | 
			
		||||
    hle/service/time/ephemeral_network_system_clock_context_writer.h
 | 
			
		||||
    hle/service/time/ephemeral_network_system_clock_core.h
 | 
			
		||||
@@ -864,6 +865,23 @@ if (ARCHITECTURE_x86_64 OR ARCHITECTURE_arm64)
 | 
			
		||||
    target_link_libraries(core PRIVATE dynarmic::dynarmic)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
if(ENABLE_OPENSSL)
 | 
			
		||||
    target_sources(core PRIVATE
 | 
			
		||||
        hle/service/ssl/ssl_backend_openssl.cpp)
 | 
			
		||||
    target_link_libraries(core PRIVATE OpenSSL::SSL)
 | 
			
		||||
elseif (APPLE)
 | 
			
		||||
    target_sources(core PRIVATE
 | 
			
		||||
        hle/service/ssl/ssl_backend_securetransport.cpp)
 | 
			
		||||
    target_link_libraries(core PRIVATE "-framework Security")
 | 
			
		||||
elseif (WIN32)
 | 
			
		||||
    target_sources(core PRIVATE
 | 
			
		||||
        hle/service/ssl/ssl_backend_schannel.cpp)
 | 
			
		||||
    target_link_libraries(core PRIVATE secur32)
 | 
			
		||||
else()
 | 
			
		||||
    target_sources(core PRIVATE
 | 
			
		||||
        hle/service/ssl/ssl_backend_none.cpp)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
if (YUZU_USE_PRECOMPILED_HEADERS)
 | 
			
		||||
    target_precompile_headers(core PRIVATE precompiled_headers.h)
 | 
			
		||||
endif()
 | 
			
		||||
 
 | 
			
		||||
@@ -20,6 +20,9 @@
 | 
			
		||||
#include "core/internal_network/sockets.h"
 | 
			
		||||
#include "network/network.h"
 | 
			
		||||
 | 
			
		||||
using Common::Expected;
 | 
			
		||||
using Common::Unexpected;
 | 
			
		||||
 | 
			
		||||
namespace Service::Sockets {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
@@ -265,16 +268,19 @@ void BSD::GetSockOpt(HLERequestContext& ctx) {
 | 
			
		||||
    const u32 level = rp.Pop<u32>();
 | 
			
		||||
    const auto optname = static_cast<OptName>(rp.Pop<u32>());
 | 
			
		||||
 | 
			
		||||
    LOG_WARNING(Service, "(STUBBED) called. fd={} level={} optname=0x{:x}", fd, level, optname);
 | 
			
		||||
 | 
			
		||||
    std::vector<u8> optval(ctx.GetWriteBufferSize());
 | 
			
		||||
 | 
			
		||||
    LOG_DEBUG(Service, "called. fd={} level={} optname=0x{:x} len=0x{:x}", fd, level, optname,
 | 
			
		||||
              optval.size());
 | 
			
		||||
 | 
			
		||||
    const Errno err = GetSockOptImpl(fd, level, optname, optval);
 | 
			
		||||
 | 
			
		||||
    ctx.WriteBuffer(optval);
 | 
			
		||||
 | 
			
		||||
    IPC::ResponseBuilder rb{ctx, 5};
 | 
			
		||||
    rb.Push(ResultSuccess);
 | 
			
		||||
    rb.Push<s32>(-1);
 | 
			
		||||
    rb.PushEnum(Errno::NOTCONN);
 | 
			
		||||
    rb.Push<s32>(err == Errno::SUCCESS ? 0 : -1);
 | 
			
		||||
    rb.PushEnum(err);
 | 
			
		||||
    rb.Push<u32>(static_cast<u32>(optval.size()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -436,6 +442,31 @@ void BSD::Close(HLERequestContext& ctx) {
 | 
			
		||||
    BuildErrnoResponse(ctx, CloseImpl(fd));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void BSD::DuplicateSocket(HLERequestContext& ctx) {
 | 
			
		||||
    struct InputParameters {
 | 
			
		||||
        s32 fd;
 | 
			
		||||
        u64 reserved;
 | 
			
		||||
    };
 | 
			
		||||
    static_assert(sizeof(InputParameters) == 0x10);
 | 
			
		||||
 | 
			
		||||
    struct OutputParameters {
 | 
			
		||||
        s32 ret;
 | 
			
		||||
        Errno bsd_errno;
 | 
			
		||||
    };
 | 
			
		||||
    static_assert(sizeof(OutputParameters) == 0x8);
 | 
			
		||||
 | 
			
		||||
    IPC::RequestParser rp{ctx};
 | 
			
		||||
    auto input = rp.PopRaw<InputParameters>();
 | 
			
		||||
 | 
			
		||||
    Expected<s32, Errno> res = DuplicateSocketImpl(input.fd);
 | 
			
		||||
    IPC::ResponseBuilder rb{ctx, 4};
 | 
			
		||||
    rb.Push(ResultSuccess);
 | 
			
		||||
    rb.PushRaw(OutputParameters{
 | 
			
		||||
        .ret = res.value_or(0),
 | 
			
		||||
        .bsd_errno = res ? Errno::SUCCESS : res.error(),
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void BSD::EventFd(HLERequestContext& ctx) {
 | 
			
		||||
    IPC::RequestParser rp{ctx};
 | 
			
		||||
    const u64 initval = rp.Pop<u64>();
 | 
			
		||||
@@ -477,12 +508,12 @@ std::pair<s32, Errno> BSD::SocketImpl(Domain domain, Type type, Protocol protoco
 | 
			
		||||
 | 
			
		||||
    auto room_member = room_network.GetRoomMember().lock();
 | 
			
		||||
    if (room_member && room_member->IsConnected()) {
 | 
			
		||||
        descriptor.socket = std::make_unique<Network::ProxySocket>(room_network);
 | 
			
		||||
        descriptor.socket = std::make_shared<Network::ProxySocket>(room_network);
 | 
			
		||||
    } else {
 | 
			
		||||
        descriptor.socket = std::make_unique<Network::Socket>();
 | 
			
		||||
        descriptor.socket = std::make_shared<Network::Socket>();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(type, protocol));
 | 
			
		||||
    descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(protocol));
 | 
			
		||||
    descriptor.is_connection_based = IsConnectionBased(type);
 | 
			
		||||
 | 
			
		||||
    return {fd, Errno::SUCCESS};
 | 
			
		||||
@@ -538,7 +569,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con
 | 
			
		||||
    std::transform(fds.begin(), fds.end(), host_pollfds.begin(), [this](PollFD pollfd) {
 | 
			
		||||
        Network::PollFD result;
 | 
			
		||||
        result.socket = file_descriptors[pollfd.fd]->socket.get();
 | 
			
		||||
        result.events = TranslatePollEventsToHost(pollfd.events);
 | 
			
		||||
        result.events = Translate(pollfd.events);
 | 
			
		||||
        result.revents = Network::PollEvents{};
 | 
			
		||||
        return result;
 | 
			
		||||
    });
 | 
			
		||||
@@ -547,7 +578,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con
 | 
			
		||||
 | 
			
		||||
    const size_t num = host_pollfds.size();
 | 
			
		||||
    for (size_t i = 0; i < num; ++i) {
 | 
			
		||||
        fds[i].revents = TranslatePollEventsToGuest(host_pollfds[i].revents);
 | 
			
		||||
        fds[i].revents = Translate(host_pollfds[i].revents);
 | 
			
		||||
    }
 | 
			
		||||
    std::memcpy(write_buffer.data(), fds.data(), length);
 | 
			
		||||
 | 
			
		||||
@@ -617,7 +648,8 @@ Errno BSD::GetPeerNameImpl(s32 fd, std::vector<u8>& write_buffer) {
 | 
			
		||||
    }
 | 
			
		||||
    const SockAddrIn guest_addrin = Translate(addr_in);
 | 
			
		||||
 | 
			
		||||
    ASSERT(write_buffer.size() == sizeof(guest_addrin));
 | 
			
		||||
    ASSERT(write_buffer.size() >= sizeof(guest_addrin));
 | 
			
		||||
    write_buffer.resize(sizeof(guest_addrin));
 | 
			
		||||
    std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin));
 | 
			
		||||
    return Translate(bsd_errno);
 | 
			
		||||
}
 | 
			
		||||
@@ -633,7 +665,8 @@ Errno BSD::GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer) {
 | 
			
		||||
    }
 | 
			
		||||
    const SockAddrIn guest_addrin = Translate(addr_in);
 | 
			
		||||
 | 
			
		||||
    ASSERT(write_buffer.size() == sizeof(guest_addrin));
 | 
			
		||||
    ASSERT(write_buffer.size() >= sizeof(guest_addrin));
 | 
			
		||||
    write_buffer.resize(sizeof(guest_addrin));
 | 
			
		||||
    std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin));
 | 
			
		||||
    return Translate(bsd_errno);
 | 
			
		||||
}
 | 
			
		||||
@@ -671,13 +704,47 @@ std::pair<s32, Errno> BSD::FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg) {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) {
 | 
			
		||||
    UNIMPLEMENTED_IF(level != 0xffff); // SOL_SOCKET
 | 
			
		||||
 | 
			
		||||
Errno BSD::GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval) {
 | 
			
		||||
    if (!IsFileDescriptorValid(fd)) {
 | 
			
		||||
        return Errno::BADF;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (level != static_cast<u32>(SocketLevel::SOCKET)) {
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unknown getsockopt level");
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Network::SocketBase* const socket = file_descriptors[fd]->socket.get();
 | 
			
		||||
 | 
			
		||||
    switch (optname) {
 | 
			
		||||
    case OptName::ERROR_: {
 | 
			
		||||
        auto [pending_err, getsockopt_err] = socket->GetPendingError();
 | 
			
		||||
        if (getsockopt_err == Network::Errno::SUCCESS) {
 | 
			
		||||
            Errno translated_pending_err = Translate(pending_err);
 | 
			
		||||
            ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
                optval.size() == sizeof(Errno), { return Errno::INVAL; },
 | 
			
		||||
                "Incorrect getsockopt option size");
 | 
			
		||||
            optval.resize(sizeof(Errno));
 | 
			
		||||
            memcpy(optval.data(), &translated_pending_err, sizeof(Errno));
 | 
			
		||||
        }
 | 
			
		||||
        return Translate(getsockopt_err);
 | 
			
		||||
    }
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented optname={}", optname);
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) {
 | 
			
		||||
    if (!IsFileDescriptorValid(fd)) {
 | 
			
		||||
        return Errno::BADF;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (level != static_cast<u32>(SocketLevel::SOCKET)) {
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unknown setsockopt level");
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Network::SocketBase* const socket = file_descriptors[fd]->socket.get();
 | 
			
		||||
 | 
			
		||||
    if (optname == OptName::LINGER) {
 | 
			
		||||
@@ -711,6 +778,9 @@ Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, con
 | 
			
		||||
        return Translate(socket->SetSndTimeo(value));
 | 
			
		||||
    case OptName::RCVTIMEO:
 | 
			
		||||
        return Translate(socket->SetRcvTimeo(value));
 | 
			
		||||
    case OptName::NOSIGPIPE:
 | 
			
		||||
        LOG_WARNING(Service, "(STUBBED) setting NOSIGPIPE to {}", value);
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented optname={}", optname);
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
@@ -841,6 +911,28 @@ Errno BSD::CloseImpl(s32 fd) {
 | 
			
		||||
    return bsd_errno;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Expected<s32, Errno> BSD::DuplicateSocketImpl(s32 fd) {
 | 
			
		||||
    if (!IsFileDescriptorValid(fd)) {
 | 
			
		||||
        return Unexpected(Errno::BADF);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const s32 new_fd = FindFreeFileDescriptorHandle();
 | 
			
		||||
    if (new_fd < 0) {
 | 
			
		||||
        LOG_ERROR(Service, "No more file descriptors available");
 | 
			
		||||
        return Unexpected(Errno::MFILE);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    file_descriptors[new_fd] = file_descriptors[fd];
 | 
			
		||||
    return new_fd;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::optional<std::shared_ptr<Network::SocketBase>> BSD::GetSocket(s32 fd) {
 | 
			
		||||
    if (!IsFileDescriptorValid(fd)) {
 | 
			
		||||
        return std::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
    return file_descriptors[fd]->socket;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
s32 BSD::FindFreeFileDescriptorHandle() noexcept {
 | 
			
		||||
    for (s32 fd = 0; fd < static_cast<s32>(file_descriptors.size()); ++fd) {
 | 
			
		||||
        if (!file_descriptors[fd]) {
 | 
			
		||||
@@ -911,7 +1003,7 @@ BSD::BSD(Core::System& system_, const char* name)
 | 
			
		||||
        {24, &BSD::Write, "Write"},
 | 
			
		||||
        {25, &BSD::Read, "Read"},
 | 
			
		||||
        {26, &BSD::Close, "Close"},
 | 
			
		||||
        {27, nullptr, "DuplicateSocket"},
 | 
			
		||||
        {27, &BSD::DuplicateSocket, "DuplicateSocket"},
 | 
			
		||||
        {28, nullptr, "GetResourceStatistics"},
 | 
			
		||||
        {29, nullptr, "RecvMMsg"},
 | 
			
		||||
        {30, nullptr, "SendMMsg"},
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "common/common_types.h"
 | 
			
		||||
#include "common/expected.h"
 | 
			
		||||
#include "common/socket_types.h"
 | 
			
		||||
#include "core/hle/service/service.h"
 | 
			
		||||
#include "core/hle/service/sockets/sockets.h"
 | 
			
		||||
@@ -29,12 +30,19 @@ public:
 | 
			
		||||
    explicit BSD(Core::System& system_, const char* name);
 | 
			
		||||
    ~BSD() override;
 | 
			
		||||
 | 
			
		||||
    // These methods are called from SSL; the first two are also called from
 | 
			
		||||
    // this class for the corresponding IPC methods.
 | 
			
		||||
    // On the real device, the SSL service makes IPC calls to this service.
 | 
			
		||||
    Common::Expected<s32, Errno> DuplicateSocketImpl(s32 fd);
 | 
			
		||||
    Errno CloseImpl(s32 fd);
 | 
			
		||||
    std::optional<std::shared_ptr<Network::SocketBase>> GetSocket(s32 fd);
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    /// Maximum number of file descriptors
 | 
			
		||||
    static constexpr size_t MAX_FD = 128;
 | 
			
		||||
 | 
			
		||||
    struct FileDescriptor {
 | 
			
		||||
        std::unique_ptr<Network::SocketBase> socket;
 | 
			
		||||
        std::shared_ptr<Network::SocketBase> socket;
 | 
			
		||||
        s32 flags = 0;
 | 
			
		||||
        bool is_connection_based = false;
 | 
			
		||||
    };
 | 
			
		||||
@@ -138,6 +146,7 @@ private:
 | 
			
		||||
    void Write(HLERequestContext& ctx);
 | 
			
		||||
    void Read(HLERequestContext& ctx);
 | 
			
		||||
    void Close(HLERequestContext& ctx);
 | 
			
		||||
    void DuplicateSocket(HLERequestContext& ctx);
 | 
			
		||||
    void EventFd(HLERequestContext& ctx);
 | 
			
		||||
 | 
			
		||||
    template <typename Work>
 | 
			
		||||
@@ -153,6 +162,7 @@ private:
 | 
			
		||||
    Errno GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer);
 | 
			
		||||
    Errno ListenImpl(s32 fd, s32 backlog);
 | 
			
		||||
    std::pair<s32, Errno> FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg);
 | 
			
		||||
    Errno GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval);
 | 
			
		||||
    Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval);
 | 
			
		||||
    Errno ShutdownImpl(s32 fd, s32 how);
 | 
			
		||||
    std::pair<s32, Errno> RecvImpl(s32 fd, u32 flags, std::vector<u8>& message);
 | 
			
		||||
@@ -161,7 +171,6 @@ private:
 | 
			
		||||
    std::pair<s32, Errno> SendImpl(s32 fd, u32 flags, std::span<const u8> message);
 | 
			
		||||
    std::pair<s32, Errno> SendToImpl(s32 fd, u32 flags, std::span<const u8> message,
 | 
			
		||||
                                     std::span<const u8> addr);
 | 
			
		||||
    Errno CloseImpl(s32 fd);
 | 
			
		||||
 | 
			
		||||
    s32 FindFreeFileDescriptorHandle() noexcept;
 | 
			
		||||
    bool IsFileDescriptorValid(s32 fd) const noexcept;
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,15 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#include "core/hle/service/ipc_helpers.h"
 | 
			
		||||
#include "core/hle/service/sockets/nsd.h"
 | 
			
		||||
 | 
			
		||||
#include "common/string_util.h"
 | 
			
		||||
 | 
			
		||||
namespace Service::Sockets {
 | 
			
		||||
 | 
			
		||||
constexpr Result ResultOverflow{ErrorModule::NSD, 6};
 | 
			
		||||
 | 
			
		||||
NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} {
 | 
			
		||||
    // clang-format off
 | 
			
		||||
    static const FunctionInfo functions[] = {
 | 
			
		||||
@@ -15,8 +20,8 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
 | 
			
		||||
        {13, nullptr, "DeleteSettings"},
 | 
			
		||||
        {14, nullptr, "ImportSettings"},
 | 
			
		||||
        {15, nullptr, "SetChangeEnvironmentIdentifierDisabled"},
 | 
			
		||||
        {20, nullptr, "Resolve"},
 | 
			
		||||
        {21, nullptr, "ResolveEx"},
 | 
			
		||||
        {20, &NSD::Resolve, "Resolve"},
 | 
			
		||||
        {21, &NSD::ResolveEx, "ResolveEx"},
 | 
			
		||||
        {30, nullptr, "GetNasServiceSetting"},
 | 
			
		||||
        {31, nullptr, "GetNasServiceSettingEx"},
 | 
			
		||||
        {40, nullptr, "GetNasRequestFqdn"},
 | 
			
		||||
@@ -40,6 +45,55 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
 | 
			
		||||
    RegisterHandlers(functions);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) {
 | 
			
		||||
    // The real implementation makes various substitutions.
 | 
			
		||||
    // For now we just return the string as-is, which is good enough when not
 | 
			
		||||
    // connecting to real Nintendo servers.
 | 
			
		||||
    LOG_WARNING(Service, "(STUBBED) called, fqdn_in={}", fqdn_in);
 | 
			
		||||
    return fqdn_in;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) {
 | 
			
		||||
    const auto res = ResolveImpl(fqdn_in);
 | 
			
		||||
    if (res.Failed()) {
 | 
			
		||||
        return res.Code();
 | 
			
		||||
    }
 | 
			
		||||
    if (res->size() >= fqdn_out.size()) {
 | 
			
		||||
        return ResultOverflow;
 | 
			
		||||
    }
 | 
			
		||||
    std::memcpy(fqdn_out.data(), res->c_str(), res->size() + 1);
 | 
			
		||||
    return ResultSuccess;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NSD::Resolve(HLERequestContext& ctx) {
 | 
			
		||||
    const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0));
 | 
			
		||||
 | 
			
		||||
    std::array<char, 0x100> fqdn_out{};
 | 
			
		||||
    const Result res = ResolveCommon(fqdn_in, fqdn_out);
 | 
			
		||||
 | 
			
		||||
    ctx.WriteBuffer(fqdn_out);
 | 
			
		||||
    IPC::ResponseBuilder rb{ctx, 2};
 | 
			
		||||
    rb.Push(res);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NSD::ResolveEx(HLERequestContext& ctx) {
 | 
			
		||||
    const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0));
 | 
			
		||||
 | 
			
		||||
    std::array<char, 0x100> fqdn_out;
 | 
			
		||||
    const Result res = ResolveCommon(fqdn_in, fqdn_out);
 | 
			
		||||
 | 
			
		||||
    if (res.IsError()) {
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 2};
 | 
			
		||||
        rb.Push(res);
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ctx.WriteBuffer(fqdn_out);
 | 
			
		||||
    IPC::ResponseBuilder rb{ctx, 4};
 | 
			
		||||
    rb.Push(ResultSuccess);
 | 
			
		||||
    rb.Push(ResultSuccess);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
NSD::~NSD() = default;
 | 
			
		||||
 | 
			
		||||
} // namespace Service::Sockets
 | 
			
		||||
 
 | 
			
		||||
@@ -15,6 +15,10 @@ class NSD final : public ServiceFramework<NSD> {
 | 
			
		||||
public:
 | 
			
		||||
    explicit NSD(Core::System& system_, const char* name);
 | 
			
		||||
    ~NSD() override;
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    void Resolve(HLERequestContext& ctx);
 | 
			
		||||
    void ResolveEx(HLERequestContext& ctx);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace Service::Sockets
 | 
			
		||||
 
 | 
			
		||||
@@ -10,27 +10,18 @@
 | 
			
		||||
#include "core/core.h"
 | 
			
		||||
#include "core/hle/service/ipc_helpers.h"
 | 
			
		||||
#include "core/hle/service/sockets/sfdnsres.h"
 | 
			
		||||
#include "core/hle/service/sockets/sockets.h"
 | 
			
		||||
#include "core/hle/service/sockets/sockets_translate.h"
 | 
			
		||||
#include "core/internal_network/network.h"
 | 
			
		||||
#include "core/memory.h"
 | 
			
		||||
 | 
			
		||||
#ifdef _WIN32
 | 
			
		||||
#include <ws2tcpip.h>
 | 
			
		||||
#elif YUZU_UNIX
 | 
			
		||||
#include <arpa/inet.h>
 | 
			
		||||
#include <netdb.h>
 | 
			
		||||
#include <netinet/in.h>
 | 
			
		||||
#include <sys/socket.h>
 | 
			
		||||
#ifndef EAI_NODATA
 | 
			
		||||
#define EAI_NODATA EAI_NONAME
 | 
			
		||||
#endif
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
namespace Service::Sockets {
 | 
			
		||||
 | 
			
		||||
SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} {
 | 
			
		||||
    static const FunctionInfo functions[] = {
 | 
			
		||||
        {0, nullptr, "SetDnsAddressesPrivateRequest"},
 | 
			
		||||
        {1, nullptr, "GetDnsAddressPrivateRequest"},
 | 
			
		||||
        {2, nullptr, "GetHostByNameRequest"},
 | 
			
		||||
        {2, &SFDNSRES::GetHostByNameRequest, "GetHostByNameRequest"},
 | 
			
		||||
        {3, nullptr, "GetHostByAddrRequest"},
 | 
			
		||||
        {4, nullptr, "GetHostStringErrorRequest"},
 | 
			
		||||
        {5, nullptr, "GetGaiStringErrorRequest"},
 | 
			
		||||
@@ -38,11 +29,11 @@ SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"
 | 
			
		||||
        {7, nullptr, "GetNameInfoRequest"},
 | 
			
		||||
        {8, nullptr, "RequestCancelHandleRequest"},
 | 
			
		||||
        {9, nullptr, "CancelRequest"},
 | 
			
		||||
        {10, nullptr, "GetHostByNameRequestWithOptions"},
 | 
			
		||||
        {10, &SFDNSRES::GetHostByNameRequestWithOptions, "GetHostByNameRequestWithOptions"},
 | 
			
		||||
        {11, nullptr, "GetHostByAddrRequestWithOptions"},
 | 
			
		||||
        {12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"},
 | 
			
		||||
        {13, nullptr, "GetNameInfoRequestWithOptions"},
 | 
			
		||||
        {14, nullptr, "ResolverSetOptionRequest"},
 | 
			
		||||
        {14, &SFDNSRES::ResolverSetOptionRequest, "ResolverSetOptionRequest"},
 | 
			
		||||
        {15, nullptr, "ResolverGetOptionRequest"},
 | 
			
		||||
    };
 | 
			
		||||
    RegisterHandlers(functions);
 | 
			
		||||
@@ -59,188 +50,285 @@ enum class NetDbError : s32 {
 | 
			
		||||
    NoData = 4,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
static NetDbError AddrInfoErrorToNetDbError(s32 result) {
 | 
			
		||||
    // Best effort guess to map errors
 | 
			
		||||
static NetDbError GetAddrInfoErrorToNetDbError(GetAddrInfoError result) {
 | 
			
		||||
    // These combinations have been verified on console (but are not
 | 
			
		||||
    // exhaustive).
 | 
			
		||||
    switch (result) {
 | 
			
		||||
    case 0:
 | 
			
		||||
    case GetAddrInfoError::SUCCESS:
 | 
			
		||||
        return NetDbError::Success;
 | 
			
		||||
    case EAI_AGAIN:
 | 
			
		||||
    case GetAddrInfoError::AGAIN:
 | 
			
		||||
        return NetDbError::TryAgain;
 | 
			
		||||
    case EAI_NODATA:
 | 
			
		||||
        return NetDbError::NoData;
 | 
			
		||||
    case GetAddrInfoError::NODATA:
 | 
			
		||||
        return NetDbError::HostNotFound;
 | 
			
		||||
    case GetAddrInfoError::SERVICE:
 | 
			
		||||
        return NetDbError::Success;
 | 
			
		||||
    default:
 | 
			
		||||
        return NetDbError::HostNotFound;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static std::vector<u8> SerializeAddrInfo(const addrinfo* addrinfo, s32 result_code,
 | 
			
		||||
static Errno GetAddrInfoErrorToErrno(GetAddrInfoError result) {
 | 
			
		||||
    // These combinations have been verified on console (but are not
 | 
			
		||||
    // exhaustive).
 | 
			
		||||
    switch (result) {
 | 
			
		||||
    case GetAddrInfoError::SUCCESS:
 | 
			
		||||
        // Note: Sometimes a successful lookup sets errno to EADDRNOTAVAIL for
 | 
			
		||||
        // some reason, but that doesn't seem useful to implement.
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    case GetAddrInfoError::AGAIN:
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    case GetAddrInfoError::NODATA:
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    case GetAddrInfoError::SERVICE:
 | 
			
		||||
        return Errno::INVAL;
 | 
			
		||||
    default:
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
static void Append(std::vector<u8>& vec, T t) {
 | 
			
		||||
    const size_t offset = vec.size();
 | 
			
		||||
    vec.resize(offset + sizeof(T));
 | 
			
		||||
    std::memcpy(vec.data() + offset, &t, sizeof(T));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void AppendNulTerminated(std::vector<u8>& vec, std::string_view str) {
 | 
			
		||||
    const size_t offset = vec.size();
 | 
			
		||||
    vec.resize(offset + str.size() + 1);
 | 
			
		||||
    std::memmove(vec.data() + offset, str.data(), str.size());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// We implement gethostbyname using the host's getaddrinfo rather than the
 | 
			
		||||
// host's gethostbyname, because it simplifies portability: e.g., getaddrinfo
 | 
			
		||||
// behaves the same on Unix and Windows, unlike gethostbyname where Windows
 | 
			
		||||
// doesn't implement h_errno.
 | 
			
		||||
static std::vector<u8> SerializeAddrInfoAsHostEnt(const std::vector<Network::AddrInfo>& vec,
 | 
			
		||||
                                                  std::string_view host) {
 | 
			
		||||
 | 
			
		||||
    std::vector<u8> data;
 | 
			
		||||
    // h_name: use the input hostname (append nul-terminated)
 | 
			
		||||
    AppendNulTerminated(data, host);
 | 
			
		||||
    // h_aliases: leave empty
 | 
			
		||||
 | 
			
		||||
    Append<u32_be>(data, 0); // count of h_aliases
 | 
			
		||||
    // (If the count were nonzero, the aliases would be appended as nul-terminated here.)
 | 
			
		||||
    Append<u16_be>(data, static_cast<u16>(Domain::INET)); // h_addrtype
 | 
			
		||||
    Append<u16_be>(data, sizeof(Network::IPv4Address));   // h_length
 | 
			
		||||
    // h_addr_list:
 | 
			
		||||
    size_t count = vec.size();
 | 
			
		||||
    ASSERT(count <= UINT32_MAX);
 | 
			
		||||
    Append<u32_be>(data, static_cast<uint32_t>(count));
 | 
			
		||||
    for (const Network::AddrInfo& addrinfo : vec) {
 | 
			
		||||
        // On the Switch, this is passed through htonl despite already being
 | 
			
		||||
        // big-endian, so it ends up as little-endian.
 | 
			
		||||
        Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip));
 | 
			
		||||
 | 
			
		||||
        LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host,
 | 
			
		||||
                 Network::IPv4AddressToString(addrinfo.addr.ip));
 | 
			
		||||
    }
 | 
			
		||||
    return data;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestContext& ctx) {
 | 
			
		||||
    struct InputParameters {
 | 
			
		||||
        u8 use_nsd_resolve;
 | 
			
		||||
        u32 cancel_handle;
 | 
			
		||||
        u64 process_id;
 | 
			
		||||
    };
 | 
			
		||||
    static_assert(sizeof(InputParameters) == 0x10);
 | 
			
		||||
 | 
			
		||||
    IPC::RequestParser rp{ctx};
 | 
			
		||||
    const auto parameters = rp.PopRaw<InputParameters>();
 | 
			
		||||
 | 
			
		||||
    LOG_WARNING(
 | 
			
		||||
        Service,
 | 
			
		||||
        "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}",
 | 
			
		||||
        parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id);
 | 
			
		||||
 | 
			
		||||
    const auto host_buffer = ctx.ReadBuffer(0);
 | 
			
		||||
    const std::string host = Common::StringFromBuffer(host_buffer);
 | 
			
		||||
    // For now, ignore options, which are in input buffer 1 for GetHostByNameRequestWithOptions.
 | 
			
		||||
 | 
			
		||||
    auto res = Network::GetAddressInfo(host, /*service*/ std::nullopt);
 | 
			
		||||
    if (!res.has_value()) {
 | 
			
		||||
        return {0, Translate(res.error())};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const std::vector<u8> data = SerializeAddrInfoAsHostEnt(res.value(), host);
 | 
			
		||||
    const u32 data_size = static_cast<u32>(data.size());
 | 
			
		||||
    ctx.WriteBuffer(data, 0);
 | 
			
		||||
 | 
			
		||||
    return {data_size, GetAddrInfoError::SUCCESS};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) {
 | 
			
		||||
    auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
 | 
			
		||||
 | 
			
		||||
    struct OutputParameters {
 | 
			
		||||
        NetDbError netdb_error;
 | 
			
		||||
        Errno bsd_errno;
 | 
			
		||||
        u32 data_size;
 | 
			
		||||
    };
 | 
			
		||||
    static_assert(sizeof(OutputParameters) == 0xc);
 | 
			
		||||
 | 
			
		||||
    IPC::ResponseBuilder rb{ctx, 5};
 | 
			
		||||
    rb.Push(ResultSuccess);
 | 
			
		||||
    rb.PushRaw(OutputParameters{
 | 
			
		||||
        .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
 | 
			
		||||
        .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
 | 
			
		||||
        .data_size = data_size,
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) {
 | 
			
		||||
    auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
 | 
			
		||||
 | 
			
		||||
    struct OutputParameters {
 | 
			
		||||
        u32 data_size;
 | 
			
		||||
        NetDbError netdb_error;
 | 
			
		||||
        Errno bsd_errno;
 | 
			
		||||
    };
 | 
			
		||||
    static_assert(sizeof(OutputParameters) == 0xc);
 | 
			
		||||
 | 
			
		||||
    IPC::ResponseBuilder rb{ctx, 5};
 | 
			
		||||
    rb.Push(ResultSuccess);
 | 
			
		||||
    rb.PushRaw(OutputParameters{
 | 
			
		||||
        .data_size = data_size,
 | 
			
		||||
        .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
 | 
			
		||||
        .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& vec,
 | 
			
		||||
                                         std::string_view host) {
 | 
			
		||||
    // Adapted from
 | 
			
		||||
    // https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190
 | 
			
		||||
    std::vector<u8> data;
 | 
			
		||||
 | 
			
		||||
    auto* current = addrinfo;
 | 
			
		||||
    while (current != nullptr) {
 | 
			
		||||
        struct SerializedResponseHeader {
 | 
			
		||||
            u32 magic;
 | 
			
		||||
            s32 flags;
 | 
			
		||||
            s32 family;
 | 
			
		||||
            s32 socket_type;
 | 
			
		||||
            s32 protocol;
 | 
			
		||||
            u32 address_length;
 | 
			
		||||
        };
 | 
			
		||||
        static_assert(sizeof(SerializedResponseHeader) == 0x18,
 | 
			
		||||
                      "Response header size must be 0x18 bytes");
 | 
			
		||||
    for (const Network::AddrInfo& addrinfo : vec) {
 | 
			
		||||
        // serialized addrinfo:
 | 
			
		||||
        Append<u32_be>(data, 0xBEEFCAFE);                                        // magic
 | 
			
		||||
        Append<u32_be>(data, 0);                                                 // ai_flags
 | 
			
		||||
        Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.family)));      // ai_family
 | 
			
		||||
        Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.socket_type))); // ai_socktype
 | 
			
		||||
        Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.protocol)));    // ai_protocol
 | 
			
		||||
        Append<u32_be>(data, sizeof(SockAddrIn));                                // ai_addrlen
 | 
			
		||||
        // ^ *not* sizeof(SerializedSockAddrIn), not that it matters since they're the same size
 | 
			
		||||
 | 
			
		||||
        constexpr auto header_size = sizeof(SerializedResponseHeader);
 | 
			
		||||
        const auto addr_size =
 | 
			
		||||
            current->ai_addr && current->ai_addrlen > 0 ? current->ai_addrlen : 4;
 | 
			
		||||
        const auto canonname_size = current->ai_canonname ? strlen(current->ai_canonname) + 1 : 1;
 | 
			
		||||
        // ai_addr:
 | 
			
		||||
        Append<u16_be>(data, static_cast<u16>(Translate(addrinfo.addr.family))); // sin_family
 | 
			
		||||
        // On the Switch, the following fields are passed through htonl despite
 | 
			
		||||
        // already being big-endian, so they end up as little-endian.
 | 
			
		||||
        Append<u16_le>(data, addrinfo.addr.portno);                            // sin_port
 | 
			
		||||
        Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); // sin_addr
 | 
			
		||||
        data.resize(data.size() + 8, 0);                                       // sin_zero
 | 
			
		||||
 | 
			
		||||
        const auto last_size = data.size();
 | 
			
		||||
        data.resize(last_size + header_size + addr_size + canonname_size);
 | 
			
		||||
 | 
			
		||||
        // Header in network byte order
 | 
			
		||||
        SerializedResponseHeader header{};
 | 
			
		||||
 | 
			
		||||
        constexpr auto HEADER_MAGIC = 0xBEEFCAFE;
 | 
			
		||||
        header.magic = htonl(HEADER_MAGIC);
 | 
			
		||||
        header.family = htonl(current->ai_family);
 | 
			
		||||
        header.flags = htonl(current->ai_flags);
 | 
			
		||||
        header.socket_type = htonl(current->ai_socktype);
 | 
			
		||||
        header.protocol = htonl(current->ai_protocol);
 | 
			
		||||
        header.address_length = current->ai_addr ? htonl((u32)current->ai_addrlen) : 0;
 | 
			
		||||
 | 
			
		||||
        auto* header_ptr = data.data() + last_size;
 | 
			
		||||
        std::memcpy(header_ptr, &header, header_size);
 | 
			
		||||
 | 
			
		||||
        if (header.address_length == 0) {
 | 
			
		||||
            std::memset(header_ptr + header_size, 0, 4);
 | 
			
		||||
        if (addrinfo.canon_name.has_value()) {
 | 
			
		||||
            AppendNulTerminated(data, *addrinfo.canon_name);
 | 
			
		||||
        } else {
 | 
			
		||||
            switch (current->ai_family) {
 | 
			
		||||
            case AF_INET: {
 | 
			
		||||
                struct SockAddrIn {
 | 
			
		||||
                    s16 sin_family;
 | 
			
		||||
                    u16 sin_port;
 | 
			
		||||
                    u32 sin_addr;
 | 
			
		||||
                    u8 sin_zero[8];
 | 
			
		||||
                };
 | 
			
		||||
 | 
			
		||||
                SockAddrIn serialized_addr{};
 | 
			
		||||
                const auto addr = *reinterpret_cast<sockaddr_in*>(current->ai_addr);
 | 
			
		||||
                serialized_addr.sin_port = htons(addr.sin_port);
 | 
			
		||||
                serialized_addr.sin_family = htons(addr.sin_family);
 | 
			
		||||
                serialized_addr.sin_addr = htonl(addr.sin_addr.s_addr);
 | 
			
		||||
                std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn));
 | 
			
		||||
 | 
			
		||||
                char addr_string_buf[64]{};
 | 
			
		||||
                inet_ntop(AF_INET, &addr.sin_addr, addr_string_buf, std::size(addr_string_buf));
 | 
			
		||||
                LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, addr_string_buf);
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
            case AF_INET6: {
 | 
			
		||||
                struct SockAddrIn6 {
 | 
			
		||||
                    s16 sin6_family;
 | 
			
		||||
                    u16 sin6_port;
 | 
			
		||||
                    u32 sin6_flowinfo;
 | 
			
		||||
                    u8 sin6_addr[16];
 | 
			
		||||
                    u32 sin6_scope_id;
 | 
			
		||||
                };
 | 
			
		||||
 | 
			
		||||
                SockAddrIn6 serialized_addr{};
 | 
			
		||||
                const auto addr = *reinterpret_cast<sockaddr_in6*>(current->ai_addr);
 | 
			
		||||
                serialized_addr.sin6_family = htons(addr.sin6_family);
 | 
			
		||||
                serialized_addr.sin6_port = htons(addr.sin6_port);
 | 
			
		||||
                serialized_addr.sin6_flowinfo = htonl(addr.sin6_flowinfo);
 | 
			
		||||
                serialized_addr.sin6_scope_id = htonl(addr.sin6_scope_id);
 | 
			
		||||
                std::memcpy(serialized_addr.sin6_addr, &addr.sin6_addr,
 | 
			
		||||
                            sizeof(SockAddrIn6::sin6_addr));
 | 
			
		||||
                std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn6));
 | 
			
		||||
 | 
			
		||||
                char addr_string_buf[64]{};
 | 
			
		||||
                inet_ntop(AF_INET6, &addr.sin6_addr, addr_string_buf, std::size(addr_string_buf));
 | 
			
		||||
                LOG_INFO(Service, "Resolved host '{}' to IPv6 address {}", host, addr_string_buf);
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
            default:
 | 
			
		||||
                std::memcpy(header_ptr + header_size, current->ai_addr, addr_size);
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        if (current->ai_canonname) {
 | 
			
		||||
            std::memcpy(header_ptr + addr_size, current->ai_canonname, canonname_size);
 | 
			
		||||
        } else {
 | 
			
		||||
            *(header_ptr + header_size + addr_size) = 0;
 | 
			
		||||
            data.push_back(0);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        current = current->ai_next;
 | 
			
		||||
        LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host,
 | 
			
		||||
                 Network::IPv4AddressToString(addrinfo.addr.ip));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // 4-byte sentinel value
 | 
			
		||||
    data.push_back(0);
 | 
			
		||||
    data.push_back(0);
 | 
			
		||||
    data.push_back(0);
 | 
			
		||||
    data.push_back(0);
 | 
			
		||||
    data.resize(data.size() + 4, 0); // 4-byte sentinel value
 | 
			
		||||
 | 
			
		||||
    return data;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static std::pair<u32, s32> GetAddrInfoRequestImpl(HLERequestContext& ctx) {
 | 
			
		||||
    struct Parameters {
 | 
			
		||||
static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext& ctx) {
 | 
			
		||||
    struct InputParameters {
 | 
			
		||||
        u8 use_nsd_resolve;
 | 
			
		||||
        u32 unknown;
 | 
			
		||||
        u32 cancel_handle;
 | 
			
		||||
        u64 process_id;
 | 
			
		||||
    };
 | 
			
		||||
    static_assert(sizeof(InputParameters) == 0x10);
 | 
			
		||||
 | 
			
		||||
    IPC::RequestParser rp{ctx};
 | 
			
		||||
    const auto parameters = rp.PopRaw<Parameters>();
 | 
			
		||||
    const auto parameters = rp.PopRaw<InputParameters>();
 | 
			
		||||
 | 
			
		||||
    LOG_WARNING(Service,
 | 
			
		||||
                "called with ignored parameters: use_nsd_resolve={}, unknown={}, process_id={}",
 | 
			
		||||
                parameters.use_nsd_resolve, parameters.unknown, parameters.process_id);
 | 
			
		||||
    LOG_WARNING(
 | 
			
		||||
        Service,
 | 
			
		||||
        "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}",
 | 
			
		||||
        parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id);
 | 
			
		||||
 | 
			
		||||
    // TODO: If use_nsd_resolve is true, pass the name through NSD::Resolve
 | 
			
		||||
    // before looking up.
 | 
			
		||||
 | 
			
		||||
    const auto host_buffer = ctx.ReadBuffer(0);
 | 
			
		||||
    const std::string host = Common::StringFromBuffer(host_buffer);
 | 
			
		||||
 | 
			
		||||
    const auto service_buffer = ctx.ReadBuffer(1);
 | 
			
		||||
    const std::string service = Common::StringFromBuffer(service_buffer);
 | 
			
		||||
 | 
			
		||||
    addrinfo* addrinfo;
 | 
			
		||||
    // Pass null for hints. Serialized hints are also passed in a buffer, but are ignored for now
 | 
			
		||||
    s32 result_code = getaddrinfo(host.c_str(), service.c_str(), nullptr, &addrinfo);
 | 
			
		||||
 | 
			
		||||
    u32 data_size = 0;
 | 
			
		||||
    if (result_code == 0 && addrinfo != nullptr) {
 | 
			
		||||
        const std::vector<u8>& data = SerializeAddrInfo(addrinfo, result_code, host);
 | 
			
		||||
        data_size = static_cast<u32>(data.size());
 | 
			
		||||
        freeaddrinfo(addrinfo);
 | 
			
		||||
 | 
			
		||||
        ctx.WriteBuffer(data, 0);
 | 
			
		||||
    std::optional<std::string> service = std::nullopt;
 | 
			
		||||
    if (ctx.CanReadBuffer(1)) {
 | 
			
		||||
        const std::span<const u8> service_buffer = ctx.ReadBuffer(1);
 | 
			
		||||
        service = Common::StringFromBuffer(service_buffer);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return std::make_pair(data_size, result_code);
 | 
			
		||||
    // Serialized hints are also passed in a buffer, but are ignored for now.
 | 
			
		||||
 | 
			
		||||
    auto res = Network::GetAddressInfo(host, service);
 | 
			
		||||
    if (!res.has_value()) {
 | 
			
		||||
        return {0, Translate(res.error())};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const std::vector<u8> data = SerializeAddrInfo(res.value(), host);
 | 
			
		||||
    const u32 data_size = static_cast<u32>(data.size());
 | 
			
		||||
    ctx.WriteBuffer(data, 0);
 | 
			
		||||
 | 
			
		||||
    return {data_size, GetAddrInfoError::SUCCESS};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) {
 | 
			
		||||
    auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx);
 | 
			
		||||
    auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
 | 
			
		||||
 | 
			
		||||
    IPC::ResponseBuilder rb{ctx, 4};
 | 
			
		||||
    struct OutputParameters {
 | 
			
		||||
        Errno bsd_errno;
 | 
			
		||||
        GetAddrInfoError gai_error;
 | 
			
		||||
        u32 data_size;
 | 
			
		||||
    };
 | 
			
		||||
    static_assert(sizeof(OutputParameters) == 0xc);
 | 
			
		||||
 | 
			
		||||
    IPC::ResponseBuilder rb{ctx, 5};
 | 
			
		||||
    rb.Push(ResultSuccess);
 | 
			
		||||
    rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode
 | 
			
		||||
    rb.Push(result_code);                                              // errno
 | 
			
		||||
    rb.Push(data_size);                                                // serialized size
 | 
			
		||||
    rb.PushRaw(OutputParameters{
 | 
			
		||||
        .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
 | 
			
		||||
        .gai_error = emu_gai_err,
 | 
			
		||||
        .data_size = data_size,
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) {
 | 
			
		||||
    // Additional options are ignored
 | 
			
		||||
    auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx);
 | 
			
		||||
    auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
 | 
			
		||||
 | 
			
		||||
    IPC::ResponseBuilder rb{ctx, 5};
 | 
			
		||||
    struct OutputParameters {
 | 
			
		||||
        u32 data_size;
 | 
			
		||||
        GetAddrInfoError gai_error;
 | 
			
		||||
        NetDbError netdb_error;
 | 
			
		||||
        Errno bsd_errno;
 | 
			
		||||
    };
 | 
			
		||||
    static_assert(sizeof(OutputParameters) == 0x10);
 | 
			
		||||
 | 
			
		||||
    IPC::ResponseBuilder rb{ctx, 6};
 | 
			
		||||
    rb.Push(ResultSuccess);
 | 
			
		||||
    rb.Push(data_size);                                                // serialized size
 | 
			
		||||
    rb.Push(result_code);                                              // errno
 | 
			
		||||
    rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode
 | 
			
		||||
    rb.Push(0);
 | 
			
		||||
    rb.PushRaw(OutputParameters{
 | 
			
		||||
        .data_size = data_size,
 | 
			
		||||
        .gai_error = emu_gai_err,
 | 
			
		||||
        .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
 | 
			
		||||
        .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) {
 | 
			
		||||
    LOG_WARNING(Service, "(STUBBED) called");
 | 
			
		||||
 | 
			
		||||
    IPC::ResponseBuilder rb{ctx, 3};
 | 
			
		||||
 | 
			
		||||
    rb.Push(ResultSuccess);
 | 
			
		||||
    rb.Push<s32>(0); // bsd errno
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Service::Sockets
 | 
			
		||||
 
 | 
			
		||||
@@ -17,8 +17,11 @@ public:
 | 
			
		||||
    ~SFDNSRES() override;
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    void GetHostByNameRequest(HLERequestContext& ctx);
 | 
			
		||||
    void GetHostByNameRequestWithOptions(HLERequestContext& ctx);
 | 
			
		||||
    void GetAddrInfoRequest(HLERequestContext& ctx);
 | 
			
		||||
    void GetAddrInfoRequestWithOptions(HLERequestContext& ctx);
 | 
			
		||||
    void ResolverSetOptionRequest(HLERequestContext& ctx);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace Service::Sockets
 | 
			
		||||
 
 | 
			
		||||
@@ -22,13 +22,35 @@ enum class Errno : u32 {
 | 
			
		||||
    CONNRESET = 104,
 | 
			
		||||
    NOTCONN = 107,
 | 
			
		||||
    TIMEDOUT = 110,
 | 
			
		||||
    INPROGRESS = 115,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
enum class GetAddrInfoError : s32 {
 | 
			
		||||
    SUCCESS = 0,
 | 
			
		||||
    ADDRFAMILY = 1,
 | 
			
		||||
    AGAIN = 2,
 | 
			
		||||
    BADFLAGS = 3,
 | 
			
		||||
    FAIL = 4,
 | 
			
		||||
    FAMILY = 5,
 | 
			
		||||
    MEMORY = 6,
 | 
			
		||||
    NODATA = 7,
 | 
			
		||||
    NONAME = 8,
 | 
			
		||||
    SERVICE = 9,
 | 
			
		||||
    SOCKTYPE = 10,
 | 
			
		||||
    SYSTEM = 11,
 | 
			
		||||
    BADHINTS = 12,
 | 
			
		||||
    PROTOCOL = 13,
 | 
			
		||||
    OVERFLOW_ = 14, // avoid name collision with Windows macro
 | 
			
		||||
    OTHER = 15,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
enum class Domain : u32 {
 | 
			
		||||
    Unspecified = 0,
 | 
			
		||||
    INET = 2,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
enum class Type : u32 {
 | 
			
		||||
    Unspecified = 0,
 | 
			
		||||
    STREAM = 1,
 | 
			
		||||
    DGRAM = 2,
 | 
			
		||||
    RAW = 3,
 | 
			
		||||
@@ -36,12 +58,16 @@ enum class Type : u32 {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
enum class Protocol : u32 {
 | 
			
		||||
    UNSPECIFIED = 0,
 | 
			
		||||
    Unspecified = 0,
 | 
			
		||||
    ICMP = 1,
 | 
			
		||||
    TCP = 6,
 | 
			
		||||
    UDP = 17,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
enum class SocketLevel : u32 {
 | 
			
		||||
    SOCKET = 0xffff, // i.e. SOL_SOCKET
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
enum class OptName : u32 {
 | 
			
		||||
    REUSEADDR = 0x4,
 | 
			
		||||
    KEEPALIVE = 0x8,
 | 
			
		||||
@@ -51,6 +77,8 @@ enum class OptName : u32 {
 | 
			
		||||
    RCVBUF = 0x1002,
 | 
			
		||||
    SNDTIMEO = 0x1005,
 | 
			
		||||
    RCVTIMEO = 0x1006,
 | 
			
		||||
    ERROR_ = 0x1007,   // avoid name collision with Windows macro
 | 
			
		||||
    NOSIGPIPE = 0x800, // at least according to libnx
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
enum class ShutdownHow : s32 {
 | 
			
		||||
@@ -80,6 +108,9 @@ enum class PollEvents : u16 {
 | 
			
		||||
    Err = 1 << 3,
 | 
			
		||||
    Hup = 1 << 4,
 | 
			
		||||
    Nval = 1 << 5,
 | 
			
		||||
    RdNorm = 1 << 6,
 | 
			
		||||
    RdBand = 1 << 7,
 | 
			
		||||
    WrBand = 1 << 8,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
DECLARE_ENUM_FLAG_OPERATORS(PollEvents);
 | 
			
		||||
 
 | 
			
		||||
@@ -29,6 +29,8 @@ Errno Translate(Network::Errno value) {
 | 
			
		||||
        return Errno::TIMEDOUT;
 | 
			
		||||
    case Network::Errno::CONNRESET:
 | 
			
		||||
        return Errno::CONNRESET;
 | 
			
		||||
    case Network::Errno::INPROGRESS:
 | 
			
		||||
        return Errno::INPROGRESS;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented errno={}", value);
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
@@ -39,8 +41,50 @@ std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value) {
 | 
			
		||||
    return {value.first, Translate(value.second)};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
GetAddrInfoError Translate(Network::GetAddrInfoError error) {
 | 
			
		||||
    switch (error) {
 | 
			
		||||
    case Network::GetAddrInfoError::SUCCESS:
 | 
			
		||||
        return GetAddrInfoError::SUCCESS;
 | 
			
		||||
    case Network::GetAddrInfoError::ADDRFAMILY:
 | 
			
		||||
        return GetAddrInfoError::ADDRFAMILY;
 | 
			
		||||
    case Network::GetAddrInfoError::AGAIN:
 | 
			
		||||
        return GetAddrInfoError::AGAIN;
 | 
			
		||||
    case Network::GetAddrInfoError::BADFLAGS:
 | 
			
		||||
        return GetAddrInfoError::BADFLAGS;
 | 
			
		||||
    case Network::GetAddrInfoError::FAIL:
 | 
			
		||||
        return GetAddrInfoError::FAIL;
 | 
			
		||||
    case Network::GetAddrInfoError::FAMILY:
 | 
			
		||||
        return GetAddrInfoError::FAMILY;
 | 
			
		||||
    case Network::GetAddrInfoError::MEMORY:
 | 
			
		||||
        return GetAddrInfoError::MEMORY;
 | 
			
		||||
    case Network::GetAddrInfoError::NODATA:
 | 
			
		||||
        return GetAddrInfoError::NODATA;
 | 
			
		||||
    case Network::GetAddrInfoError::NONAME:
 | 
			
		||||
        return GetAddrInfoError::NONAME;
 | 
			
		||||
    case Network::GetAddrInfoError::SERVICE:
 | 
			
		||||
        return GetAddrInfoError::SERVICE;
 | 
			
		||||
    case Network::GetAddrInfoError::SOCKTYPE:
 | 
			
		||||
        return GetAddrInfoError::SOCKTYPE;
 | 
			
		||||
    case Network::GetAddrInfoError::SYSTEM:
 | 
			
		||||
        return GetAddrInfoError::SYSTEM;
 | 
			
		||||
    case Network::GetAddrInfoError::BADHINTS:
 | 
			
		||||
        return GetAddrInfoError::BADHINTS;
 | 
			
		||||
    case Network::GetAddrInfoError::PROTOCOL:
 | 
			
		||||
        return GetAddrInfoError::PROTOCOL;
 | 
			
		||||
    case Network::GetAddrInfoError::OVERFLOW_:
 | 
			
		||||
        return GetAddrInfoError::OVERFLOW_;
 | 
			
		||||
    case Network::GetAddrInfoError::OTHER:
 | 
			
		||||
        return GetAddrInfoError::OTHER;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented GetAddrInfoError={}", error);
 | 
			
		||||
        return GetAddrInfoError::OTHER;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Network::Domain Translate(Domain domain) {
 | 
			
		||||
    switch (domain) {
 | 
			
		||||
    case Domain::Unspecified:
 | 
			
		||||
        return Network::Domain::Unspecified;
 | 
			
		||||
    case Domain::INET:
 | 
			
		||||
        return Network::Domain::INET;
 | 
			
		||||
    default:
 | 
			
		||||
@@ -51,6 +95,8 @@ Network::Domain Translate(Domain domain) {
 | 
			
		||||
 | 
			
		||||
Domain Translate(Network::Domain domain) {
 | 
			
		||||
    switch (domain) {
 | 
			
		||||
    case Network::Domain::Unspecified:
 | 
			
		||||
        return Domain::Unspecified;
 | 
			
		||||
    case Network::Domain::INET:
 | 
			
		||||
        return Domain::INET;
 | 
			
		||||
    default:
 | 
			
		||||
@@ -61,39 +107,69 @@ Domain Translate(Network::Domain domain) {
 | 
			
		||||
 | 
			
		||||
Network::Type Translate(Type type) {
 | 
			
		||||
    switch (type) {
 | 
			
		||||
    case Type::Unspecified:
 | 
			
		||||
        return Network::Type::Unspecified;
 | 
			
		||||
    case Type::STREAM:
 | 
			
		||||
        return Network::Type::STREAM;
 | 
			
		||||
    case Type::DGRAM:
 | 
			
		||||
        return Network::Type::DGRAM;
 | 
			
		||||
    case Type::RAW:
 | 
			
		||||
        return Network::Type::RAW;
 | 
			
		||||
    case Type::SEQPACKET:
 | 
			
		||||
        return Network::Type::SEQPACKET;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented type={}", type);
 | 
			
		||||
        return Network::Type{};
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Network::Protocol Translate(Type type, Protocol protocol) {
 | 
			
		||||
Type Translate(Network::Type type) {
 | 
			
		||||
    switch (type) {
 | 
			
		||||
    case Network::Type::Unspecified:
 | 
			
		||||
        return Type::Unspecified;
 | 
			
		||||
    case Network::Type::STREAM:
 | 
			
		||||
        return Type::STREAM;
 | 
			
		||||
    case Network::Type::DGRAM:
 | 
			
		||||
        return Type::DGRAM;
 | 
			
		||||
    case Network::Type::RAW:
 | 
			
		||||
        return Type::RAW;
 | 
			
		||||
    case Network::Type::SEQPACKET:
 | 
			
		||||
        return Type::SEQPACKET;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented type={}", type);
 | 
			
		||||
        return Type{};
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Network::Protocol Translate(Protocol protocol) {
 | 
			
		||||
    switch (protocol) {
 | 
			
		||||
    case Protocol::UNSPECIFIED:
 | 
			
		||||
        LOG_WARNING(Service, "Unspecified protocol, assuming protocol from type");
 | 
			
		||||
        switch (type) {
 | 
			
		||||
        case Type::DGRAM:
 | 
			
		||||
            return Network::Protocol::UDP;
 | 
			
		||||
        case Type::STREAM:
 | 
			
		||||
            return Network::Protocol::TCP;
 | 
			
		||||
        default:
 | 
			
		||||
            return Network::Protocol::TCP;
 | 
			
		||||
        }
 | 
			
		||||
    case Protocol::Unspecified:
 | 
			
		||||
        return Network::Protocol::Unspecified;
 | 
			
		||||
    case Protocol::TCP:
 | 
			
		||||
        return Network::Protocol::TCP;
 | 
			
		||||
    case Protocol::UDP:
 | 
			
		||||
        return Network::Protocol::UDP;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
 | 
			
		||||
        return Network::Protocol::TCP;
 | 
			
		||||
        return Network::Protocol::Unspecified;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Network::PollEvents TranslatePollEventsToHost(PollEvents flags) {
 | 
			
		||||
Protocol Translate(Network::Protocol protocol) {
 | 
			
		||||
    switch (protocol) {
 | 
			
		||||
    case Network::Protocol::Unspecified:
 | 
			
		||||
        return Protocol::Unspecified;
 | 
			
		||||
    case Network::Protocol::TCP:
 | 
			
		||||
        return Protocol::TCP;
 | 
			
		||||
    case Network::Protocol::UDP:
 | 
			
		||||
        return Protocol::UDP;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
 | 
			
		||||
        return Protocol::Unspecified;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Network::PollEvents Translate(PollEvents flags) {
 | 
			
		||||
    Network::PollEvents result{};
 | 
			
		||||
    const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) {
 | 
			
		||||
        if (True(flags & from)) {
 | 
			
		||||
@@ -107,12 +183,15 @@ Network::PollEvents TranslatePollEventsToHost(PollEvents flags) {
 | 
			
		||||
    translate(PollEvents::Err, Network::PollEvents::Err);
 | 
			
		||||
    translate(PollEvents::Hup, Network::PollEvents::Hup);
 | 
			
		||||
    translate(PollEvents::Nval, Network::PollEvents::Nval);
 | 
			
		||||
    translate(PollEvents::RdNorm, Network::PollEvents::RdNorm);
 | 
			
		||||
    translate(PollEvents::RdBand, Network::PollEvents::RdBand);
 | 
			
		||||
    translate(PollEvents::WrBand, Network::PollEvents::WrBand);
 | 
			
		||||
 | 
			
		||||
    UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags);
 | 
			
		||||
    return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) {
 | 
			
		||||
PollEvents Translate(Network::PollEvents flags) {
 | 
			
		||||
    PollEvents result{};
 | 
			
		||||
    const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) {
 | 
			
		||||
        if (True(flags & from)) {
 | 
			
		||||
@@ -127,13 +206,18 @@ PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) {
 | 
			
		||||
    translate(Network::PollEvents::Err, PollEvents::Err);
 | 
			
		||||
    translate(Network::PollEvents::Hup, PollEvents::Hup);
 | 
			
		||||
    translate(Network::PollEvents::Nval, PollEvents::Nval);
 | 
			
		||||
    translate(Network::PollEvents::RdNorm, PollEvents::RdNorm);
 | 
			
		||||
    translate(Network::PollEvents::RdBand, PollEvents::RdBand);
 | 
			
		||||
    translate(Network::PollEvents::WrBand, PollEvents::WrBand);
 | 
			
		||||
 | 
			
		||||
    UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags);
 | 
			
		||||
    return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Network::SockAddrIn Translate(SockAddrIn value) {
 | 
			
		||||
    ASSERT(value.len == 0 || value.len == sizeof(value));
 | 
			
		||||
    // Note: 6 is incorrect, but can be passed by homebrew (because libnx sets
 | 
			
		||||
    // sin_len to 6 when deserializing getaddrinfo results).
 | 
			
		||||
    ASSERT(value.len == 0 || value.len == sizeof(value) || value.len == 6);
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
        .family = Translate(static_cast<Domain>(value.family)),
 | 
			
		||||
 
 | 
			
		||||
@@ -17,6 +17,9 @@ Errno Translate(Network::Errno value);
 | 
			
		||||
/// Translate abstract return value errno pair to guest return value errno pair
 | 
			
		||||
std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value);
 | 
			
		||||
 | 
			
		||||
/// Translate abstract getaddrinfo error to guest getaddrinfo error
 | 
			
		||||
GetAddrInfoError Translate(Network::GetAddrInfoError value);
 | 
			
		||||
 | 
			
		||||
/// Translate guest domain to abstract domain
 | 
			
		||||
Network::Domain Translate(Domain domain);
 | 
			
		||||
 | 
			
		||||
@@ -26,14 +29,20 @@ Domain Translate(Network::Domain domain);
 | 
			
		||||
/// Translate guest type to abstract type
 | 
			
		||||
Network::Type Translate(Type type);
 | 
			
		||||
 | 
			
		||||
/// Translate guest protocol to abstract protocol
 | 
			
		||||
Network::Protocol Translate(Type type, Protocol protocol);
 | 
			
		||||
/// Translate abstract type to guest type
 | 
			
		||||
Type Translate(Network::Type type);
 | 
			
		||||
 | 
			
		||||
/// Translate abstract poll event flags to guest poll event flags
 | 
			
		||||
Network::PollEvents TranslatePollEventsToHost(PollEvents flags);
 | 
			
		||||
/// Translate guest protocol to abstract protocol
 | 
			
		||||
Network::Protocol Translate(Protocol protocol);
 | 
			
		||||
 | 
			
		||||
/// Translate abstract protocol to guest protocol
 | 
			
		||||
Protocol Translate(Network::Protocol protocol);
 | 
			
		||||
 | 
			
		||||
/// Translate guest poll event flags to abstract poll event flags
 | 
			
		||||
PollEvents TranslatePollEventsToGuest(Network::PollEvents flags);
 | 
			
		||||
Network::PollEvents Translate(PollEvents flags);
 | 
			
		||||
 | 
			
		||||
/// Translate abstract poll event flags to guest poll event flags
 | 
			
		||||
PollEvents Translate(Network::PollEvents flags);
 | 
			
		||||
 | 
			
		||||
/// Translate guest socket address structure to abstract socket address structure
 | 
			
		||||
Network::SockAddrIn Translate(SockAddrIn value);
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,18 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#include "common/string_util.h"
 | 
			
		||||
 | 
			
		||||
#include "core/core.h"
 | 
			
		||||
#include "core/hle/service/ipc_helpers.h"
 | 
			
		||||
#include "core/hle/service/server_manager.h"
 | 
			
		||||
#include "core/hle/service/service.h"
 | 
			
		||||
#include "core/hle/service/sm/sm.h"
 | 
			
		||||
#include "core/hle/service/sockets/bsd.h"
 | 
			
		||||
#include "core/hle/service/ssl/ssl.h"
 | 
			
		||||
#include "core/hle/service/ssl/ssl_backend.h"
 | 
			
		||||
#include "core/internal_network/network.h"
 | 
			
		||||
#include "core/internal_network/sockets.h"
 | 
			
		||||
 | 
			
		||||
namespace Service::SSL {
 | 
			
		||||
 | 
			
		||||
@@ -20,6 +28,18 @@ enum class ContextOption : u32 {
 | 
			
		||||
    CrlImportDateCheckEnable = 1,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// This is nn::ssl::Connection::IoMode
 | 
			
		||||
enum class IoMode : u32 {
 | 
			
		||||
    Blocking = 1,
 | 
			
		||||
    NonBlocking = 2,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// This is nn::ssl::sf::OptionType
 | 
			
		||||
enum class OptionType : u32 {
 | 
			
		||||
    DoNotCloseSocket = 0,
 | 
			
		||||
    GetServerCertChain = 1,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// This is nn::ssl::sf::SslVersion
 | 
			
		||||
struct SslVersion {
 | 
			
		||||
    union {
 | 
			
		||||
@@ -34,35 +54,42 @@ struct SslVersion {
 | 
			
		||||
    };
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct SslContextSharedData {
 | 
			
		||||
    u32 connection_count = 0;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class ISslConnection final : public ServiceFramework<ISslConnection> {
 | 
			
		||||
public:
 | 
			
		||||
    explicit ISslConnection(Core::System& system_, SslVersion version)
 | 
			
		||||
        : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} {
 | 
			
		||||
    explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in,
 | 
			
		||||
                            std::shared_ptr<SslContextSharedData>& shared_data_in,
 | 
			
		||||
                            std::unique_ptr<SSLConnectionBackend>&& backend_in)
 | 
			
		||||
        : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in},
 | 
			
		||||
          shared_data{shared_data_in}, backend{std::move(backend_in)} {
 | 
			
		||||
        // clang-format off
 | 
			
		||||
        static const FunctionInfo functions[] = {
 | 
			
		||||
            {0, nullptr, "SetSocketDescriptor"},
 | 
			
		||||
            {1, nullptr, "SetHostName"},
 | 
			
		||||
            {2, nullptr, "SetVerifyOption"},
 | 
			
		||||
            {3, nullptr, "SetIoMode"},
 | 
			
		||||
            {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
 | 
			
		||||
            {1, &ISslConnection::SetHostName, "SetHostName"},
 | 
			
		||||
            {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"},
 | 
			
		||||
            {3, &ISslConnection::SetIoMode, "SetIoMode"},
 | 
			
		||||
            {4, nullptr, "GetSocketDescriptor"},
 | 
			
		||||
            {5, nullptr, "GetHostName"},
 | 
			
		||||
            {6, nullptr, "GetVerifyOption"},
 | 
			
		||||
            {7, nullptr, "GetIoMode"},
 | 
			
		||||
            {8, nullptr, "DoHandshake"},
 | 
			
		||||
            {9, nullptr, "DoHandshakeGetServerCert"},
 | 
			
		||||
            {10, nullptr, "Read"},
 | 
			
		||||
            {11, nullptr, "Write"},
 | 
			
		||||
            {12, nullptr, "Pending"},
 | 
			
		||||
            {8, &ISslConnection::DoHandshake, "DoHandshake"},
 | 
			
		||||
            {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"},
 | 
			
		||||
            {10, &ISslConnection::Read, "Read"},
 | 
			
		||||
            {11, &ISslConnection::Write, "Write"},
 | 
			
		||||
            {12, &ISslConnection::Pending, "Pending"},
 | 
			
		||||
            {13, nullptr, "Peek"},
 | 
			
		||||
            {14, nullptr, "Poll"},
 | 
			
		||||
            {15, nullptr, "GetVerifyCertError"},
 | 
			
		||||
            {16, nullptr, "GetNeededServerCertBufferSize"},
 | 
			
		||||
            {17, nullptr, "SetSessionCacheMode"},
 | 
			
		||||
            {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"},
 | 
			
		||||
            {18, nullptr, "GetSessionCacheMode"},
 | 
			
		||||
            {19, nullptr, "FlushSessionCache"},
 | 
			
		||||
            {20, nullptr, "SetRenegotiationMode"},
 | 
			
		||||
            {21, nullptr, "GetRenegotiationMode"},
 | 
			
		||||
            {22, nullptr, "SetOption"},
 | 
			
		||||
            {22, &ISslConnection::SetOption, "SetOption"},
 | 
			
		||||
            {23, nullptr, "GetOption"},
 | 
			
		||||
            {24, nullptr, "GetVerifyCertErrors"},
 | 
			
		||||
            {25, nullptr, "GetCipherInfo"},
 | 
			
		||||
@@ -80,21 +107,299 @@ public:
 | 
			
		||||
        // clang-format on
 | 
			
		||||
 | 
			
		||||
        RegisterHandlers(functions);
 | 
			
		||||
 | 
			
		||||
        shared_data->connection_count++;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ~ISslConnection() {
 | 
			
		||||
        shared_data->connection_count--;
 | 
			
		||||
        if (fd_to_close.has_value()) {
 | 
			
		||||
            const s32 fd = *fd_to_close;
 | 
			
		||||
            if (!do_not_close_socket) {
 | 
			
		||||
                LOG_ERROR(Service_SSL,
 | 
			
		||||
                          "do_not_close_socket was changed after setting socket; is this right?");
 | 
			
		||||
            } else {
 | 
			
		||||
                auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
 | 
			
		||||
                if (bsd) {
 | 
			
		||||
                    auto err = bsd->CloseImpl(fd);
 | 
			
		||||
                    if (err != Service::Sockets::Errno::SUCCESS) {
 | 
			
		||||
                        LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    SslVersion ssl_version;
 | 
			
		||||
    std::shared_ptr<SslContextSharedData> shared_data;
 | 
			
		||||
    std::unique_ptr<SSLConnectionBackend> backend;
 | 
			
		||||
    std::optional<int> fd_to_close;
 | 
			
		||||
    bool do_not_close_socket = false;
 | 
			
		||||
    bool get_server_cert_chain = false;
 | 
			
		||||
    std::shared_ptr<Network::SocketBase> socket;
 | 
			
		||||
    bool did_set_host_name = false;
 | 
			
		||||
    bool did_handshake = false;
 | 
			
		||||
 | 
			
		||||
    ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
 | 
			
		||||
        LOG_DEBUG(Service_SSL, "called, fd={}", fd);
 | 
			
		||||
        ASSERT(!did_handshake);
 | 
			
		||||
        auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
 | 
			
		||||
        ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
 | 
			
		||||
        s32 ret_fd;
 | 
			
		||||
        // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
 | 
			
		||||
        if (do_not_close_socket) {
 | 
			
		||||
            auto res = bsd->DuplicateSocketImpl(fd);
 | 
			
		||||
            if (!res.has_value()) {
 | 
			
		||||
                LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd);
 | 
			
		||||
                return ResultInvalidSocket;
 | 
			
		||||
            }
 | 
			
		||||
            fd = *res;
 | 
			
		||||
            fd_to_close = fd;
 | 
			
		||||
            ret_fd = fd;
 | 
			
		||||
        } else {
 | 
			
		||||
            ret_fd = -1;
 | 
			
		||||
        }
 | 
			
		||||
        std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
 | 
			
		||||
        if (!sock.has_value()) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
 | 
			
		||||
            return ResultInvalidSocket;
 | 
			
		||||
        }
 | 
			
		||||
        socket = std::move(*sock);
 | 
			
		||||
        backend->SetSocket(socket);
 | 
			
		||||
        return ret_fd;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result SetHostNameImpl(const std::string& hostname) {
 | 
			
		||||
        LOG_DEBUG(Service_SSL, "called. hostname={}", hostname);
 | 
			
		||||
        ASSERT(!did_handshake);
 | 
			
		||||
        Result res = backend->SetHostName(hostname);
 | 
			
		||||
        if (res == ResultSuccess) {
 | 
			
		||||
            did_set_host_name = true;
 | 
			
		||||
        }
 | 
			
		||||
        return res;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result SetVerifyOptionImpl(u32 option) {
 | 
			
		||||
        ASSERT(!did_handshake);
 | 
			
		||||
        LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result SetIoModeImpl(u32 input_mode) {
 | 
			
		||||
        auto mode = static_cast<IoMode>(input_mode);
 | 
			
		||||
        ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
 | 
			
		||||
        ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; });
 | 
			
		||||
 | 
			
		||||
        const bool non_block = mode == IoMode::NonBlocking;
 | 
			
		||||
        const Network::Errno error = socket->SetNonBlock(non_block);
 | 
			
		||||
        if (error != Network::Errno::SUCCESS) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
 | 
			
		||||
        }
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result SetSessionCacheModeImpl(u32 mode) {
 | 
			
		||||
        ASSERT(!did_handshake);
 | 
			
		||||
        LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result DoHandshakeImpl() {
 | 
			
		||||
        ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; });
 | 
			
		||||
        ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
            did_set_host_name, { return ResultInternalError; },
 | 
			
		||||
            "Expected SetHostName before DoHandshake");
 | 
			
		||||
        Result res = backend->DoHandshake();
 | 
			
		||||
        did_handshake = res.IsSuccess();
 | 
			
		||||
        return res;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) {
 | 
			
		||||
        struct Header {
 | 
			
		||||
            u64 magic;
 | 
			
		||||
            u32 count;
 | 
			
		||||
            u32 pad;
 | 
			
		||||
        };
 | 
			
		||||
        struct EntryHeader {
 | 
			
		||||
            u32 size;
 | 
			
		||||
            u32 offset;
 | 
			
		||||
        };
 | 
			
		||||
        if (!get_server_cert_chain) {
 | 
			
		||||
            // Just return the first one, unencoded.
 | 
			
		||||
            ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
                !certs.empty(), { return {}; }, "Should be at least one server cert");
 | 
			
		||||
            return certs[0];
 | 
			
		||||
        }
 | 
			
		||||
        std::vector<u8> ret;
 | 
			
		||||
        Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0};
 | 
			
		||||
        ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1));
 | 
			
		||||
        size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader);
 | 
			
		||||
        for (auto& cert : certs) {
 | 
			
		||||
            EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)};
 | 
			
		||||
            data_offset += cert.size();
 | 
			
		||||
            ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header),
 | 
			
		||||
                       reinterpret_cast<u8*>(&entry_header + 1));
 | 
			
		||||
        }
 | 
			
		||||
        for (auto& cert : certs) {
 | 
			
		||||
            ret.insert(ret.end(), cert.begin(), cert.end());
 | 
			
		||||
        }
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<std::vector<u8>> ReadImpl(size_t size) {
 | 
			
		||||
        ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
 | 
			
		||||
        std::vector<u8> res(size);
 | 
			
		||||
        ResultVal<size_t> actual = backend->Read(res);
 | 
			
		||||
        if (actual.Failed()) {
 | 
			
		||||
            return actual.Code();
 | 
			
		||||
        }
 | 
			
		||||
        res.resize(*actual);
 | 
			
		||||
        return res;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> WriteImpl(std::span<const u8> data) {
 | 
			
		||||
        ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
 | 
			
		||||
        return backend->Write(data);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<s32> PendingImpl() {
 | 
			
		||||
        LOG_WARNING(Service_SSL, "(STUBBED) called.");
 | 
			
		||||
        return 0;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetSocketDescriptor(HLERequestContext& ctx) {
 | 
			
		||||
        IPC::RequestParser rp{ctx};
 | 
			
		||||
        const s32 fd = rp.Pop<s32>();
 | 
			
		||||
        const ResultVal<s32> res = SetSocketDescriptorImpl(fd);
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 3};
 | 
			
		||||
        rb.Push(res.Code());
 | 
			
		||||
        rb.Push<s32>(res.ValueOr(-1));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetHostName(HLERequestContext& ctx) {
 | 
			
		||||
        const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer());
 | 
			
		||||
        const Result res = SetHostNameImpl(hostname);
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 2};
 | 
			
		||||
        rb.Push(res);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetVerifyOption(HLERequestContext& ctx) {
 | 
			
		||||
        IPC::RequestParser rp{ctx};
 | 
			
		||||
        const u32 option = rp.Pop<u32>();
 | 
			
		||||
        const Result res = SetVerifyOptionImpl(option);
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 2};
 | 
			
		||||
        rb.Push(res);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetIoMode(HLERequestContext& ctx) {
 | 
			
		||||
        IPC::RequestParser rp{ctx};
 | 
			
		||||
        const u32 mode = rp.Pop<u32>();
 | 
			
		||||
        const Result res = SetIoModeImpl(mode);
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 2};
 | 
			
		||||
        rb.Push(res);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void DoHandshake(HLERequestContext& ctx) {
 | 
			
		||||
        const Result res = DoHandshakeImpl();
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 2};
 | 
			
		||||
        rb.Push(res);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void DoHandshakeGetServerCert(HLERequestContext& ctx) {
 | 
			
		||||
        struct OutputParameters {
 | 
			
		||||
            u32 certs_size;
 | 
			
		||||
            u32 certs_count;
 | 
			
		||||
        };
 | 
			
		||||
        static_assert(sizeof(OutputParameters) == 0x8);
 | 
			
		||||
 | 
			
		||||
        const Result res = DoHandshakeImpl();
 | 
			
		||||
        OutputParameters out{};
 | 
			
		||||
        if (res == ResultSuccess) {
 | 
			
		||||
            auto certs = backend->GetServerCerts();
 | 
			
		||||
            if (certs.Succeeded()) {
 | 
			
		||||
                const std::vector<u8> certs_buf = SerializeServerCerts(*certs);
 | 
			
		||||
                ctx.WriteBuffer(certs_buf);
 | 
			
		||||
                out.certs_count = static_cast<u32>(certs->size());
 | 
			
		||||
                out.certs_size = static_cast<u32>(certs_buf.size());
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 4};
 | 
			
		||||
        rb.Push(res);
 | 
			
		||||
        rb.PushRaw(out);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void Read(HLERequestContext& ctx) {
 | 
			
		||||
        const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize());
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 3};
 | 
			
		||||
        rb.Push(res.Code());
 | 
			
		||||
        if (res.Succeeded()) {
 | 
			
		||||
            rb.Push(static_cast<u32>(res->size()));
 | 
			
		||||
            ctx.WriteBuffer(*res);
 | 
			
		||||
        } else {
 | 
			
		||||
            rb.Push(static_cast<u32>(0));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void Write(HLERequestContext& ctx) {
 | 
			
		||||
        const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer());
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 3};
 | 
			
		||||
        rb.Push(res.Code());
 | 
			
		||||
        rb.Push(static_cast<u32>(res.ValueOr(0)));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void Pending(HLERequestContext& ctx) {
 | 
			
		||||
        const ResultVal<s32> res = PendingImpl();
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 3};
 | 
			
		||||
        rb.Push(res.Code());
 | 
			
		||||
        rb.Push<s32>(res.ValueOr(0));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetSessionCacheMode(HLERequestContext& ctx) {
 | 
			
		||||
        IPC::RequestParser rp{ctx};
 | 
			
		||||
        const u32 mode = rp.Pop<u32>();
 | 
			
		||||
        const Result res = SetSessionCacheModeImpl(mode);
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 2};
 | 
			
		||||
        rb.Push(res);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetOption(HLERequestContext& ctx) {
 | 
			
		||||
        struct Parameters {
 | 
			
		||||
            OptionType option;
 | 
			
		||||
            s32 value;
 | 
			
		||||
        };
 | 
			
		||||
        static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size");
 | 
			
		||||
 | 
			
		||||
        IPC::RequestParser rp{ctx};
 | 
			
		||||
        const auto parameters = rp.PopRaw<Parameters>();
 | 
			
		||||
 | 
			
		||||
        switch (parameters.option) {
 | 
			
		||||
        case OptionType::DoNotCloseSocket:
 | 
			
		||||
            do_not_close_socket = static_cast<bool>(parameters.value);
 | 
			
		||||
            break;
 | 
			
		||||
        case OptionType::GetServerCertChain:
 | 
			
		||||
            get_server_cert_chain = static_cast<bool>(parameters.value);
 | 
			
		||||
            break;
 | 
			
		||||
        default:
 | 
			
		||||
            LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option,
 | 
			
		||||
                        parameters.value);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 2};
 | 
			
		||||
        rb.Push(ResultSuccess);
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class ISslContext final : public ServiceFramework<ISslContext> {
 | 
			
		||||
public:
 | 
			
		||||
    explicit ISslContext(Core::System& system_, SslVersion version)
 | 
			
		||||
        : ServiceFramework{system_, "ISslContext"}, ssl_version{version} {
 | 
			
		||||
        : ServiceFramework{system_, "ISslContext"}, ssl_version{version},
 | 
			
		||||
          shared_data{std::make_shared<SslContextSharedData>()} {
 | 
			
		||||
        static const FunctionInfo functions[] = {
 | 
			
		||||
            {0, &ISslContext::SetOption, "SetOption"},
 | 
			
		||||
            {1, nullptr, "GetOption"},
 | 
			
		||||
            {2, &ISslContext::CreateConnection, "CreateConnection"},
 | 
			
		||||
            {3, nullptr, "GetConnectionCount"},
 | 
			
		||||
            {3, &ISslContext::GetConnectionCount, "GetConnectionCount"},
 | 
			
		||||
            {4, &ISslContext::ImportServerPki, "ImportServerPki"},
 | 
			
		||||
            {5, &ISslContext::ImportClientPki, "ImportClientPki"},
 | 
			
		||||
            {6, nullptr, "RemoveServerPki"},
 | 
			
		||||
@@ -111,6 +416,7 @@ public:
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    SslVersion ssl_version;
 | 
			
		||||
    std::shared_ptr<SslContextSharedData> shared_data;
 | 
			
		||||
 | 
			
		||||
    void SetOption(HLERequestContext& ctx) {
 | 
			
		||||
        struct Parameters {
 | 
			
		||||
@@ -130,11 +436,24 @@ private:
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void CreateConnection(HLERequestContext& ctx) {
 | 
			
		||||
        LOG_WARNING(Service_SSL, "(STUBBED) called");
 | 
			
		||||
        LOG_WARNING(Service_SSL, "called");
 | 
			
		||||
 | 
			
		||||
        auto backend_res = CreateSSLConnectionBackend();
 | 
			
		||||
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 2, 0, 1};
 | 
			
		||||
        rb.Push(backend_res.Code());
 | 
			
		||||
        if (backend_res.Succeeded()) {
 | 
			
		||||
            rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data,
 | 
			
		||||
                                                std::move(*backend_res));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void GetConnectionCount(HLERequestContext& ctx) {
 | 
			
		||||
        LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count);
 | 
			
		||||
 | 
			
		||||
        IPC::ResponseBuilder rb{ctx, 3};
 | 
			
		||||
        rb.Push(ResultSuccess);
 | 
			
		||||
        rb.PushIpcInterface<ISslConnection>(system, ssl_version);
 | 
			
		||||
        rb.Push(shared_data->connection_count);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void ImportServerPki(HLERequestContext& ctx) {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										45
									
								
								src/core/hle/service/ssl/ssl_backend.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								src/core/hle/service/ssl/ssl_backend.h
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,45 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include "core/hle/result.h"
 | 
			
		||||
 | 
			
		||||
#include "common/common_types.h"
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <span>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
namespace Network {
 | 
			
		||||
class SocketBase;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace Service::SSL {
 | 
			
		||||
 | 
			
		||||
constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103};
 | 
			
		||||
constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106};
 | 
			
		||||
constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205};
 | 
			
		||||
constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up
 | 
			
		||||
 | 
			
		||||
// ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake,
 | 
			
		||||
// with no way in the latter case to distinguish whether the client should poll
 | 
			
		||||
// for read or write.  The one official client I've seen handles this by always
 | 
			
		||||
// polling for read (with a timeout).
 | 
			
		||||
constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204};
 | 
			
		||||
 | 
			
		||||
class SSLConnectionBackend {
 | 
			
		||||
public:
 | 
			
		||||
    virtual ~SSLConnectionBackend() {}
 | 
			
		||||
    virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
 | 
			
		||||
    virtual Result SetHostName(const std::string& hostname) = 0;
 | 
			
		||||
    virtual Result DoHandshake() = 0;
 | 
			
		||||
    virtual ResultVal<size_t> Read(std::span<u8> data) = 0;
 | 
			
		||||
    virtual ResultVal<size_t> Write(std::span<const u8> data) = 0;
 | 
			
		||||
    virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend();
 | 
			
		||||
 | 
			
		||||
} // namespace Service::SSL
 | 
			
		||||
							
								
								
									
										16
									
								
								src/core/hle/service/ssl/ssl_backend_none.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								src/core/hle/service/ssl/ssl_backend_none.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,16 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#include "core/hle/service/ssl/ssl_backend.h"
 | 
			
		||||
 | 
			
		||||
#include "common/logging/log.h"
 | 
			
		||||
 | 
			
		||||
namespace Service::SSL {
 | 
			
		||||
 | 
			
		||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
 | 
			
		||||
    LOG_ERROR(Service_SSL,
 | 
			
		||||
              "Can't create SSL connection because no SSL backend is available on this platform");
 | 
			
		||||
    return ResultInternalError;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Service::SSL
 | 
			
		||||
							
								
								
									
										351
									
								
								src/core/hle/service/ssl/ssl_backend_openssl.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										351
									
								
								src/core/hle/service/ssl/ssl_backend_openssl.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,351 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#include "core/hle/service/ssl/ssl_backend.h"
 | 
			
		||||
#include "core/internal_network/network.h"
 | 
			
		||||
#include "core/internal_network/sockets.h"
 | 
			
		||||
 | 
			
		||||
#include "common/fs/file.h"
 | 
			
		||||
#include "common/hex_util.h"
 | 
			
		||||
#include "common/string_util.h"
 | 
			
		||||
 | 
			
		||||
#include <mutex>
 | 
			
		||||
 | 
			
		||||
#include <openssl/bio.h>
 | 
			
		||||
#include <openssl/err.h>
 | 
			
		||||
#include <openssl/ssl.h>
 | 
			
		||||
#include <openssl/x509.h>
 | 
			
		||||
 | 
			
		||||
using namespace Common::FS;
 | 
			
		||||
 | 
			
		||||
namespace Service::SSL {
 | 
			
		||||
 | 
			
		||||
// Import OpenSSL's `SSL` type into the namespace.  This is needed because the
 | 
			
		||||
// namespace is also named `SSL`.
 | 
			
		||||
using ::SSL;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
std::once_flag one_time_init_flag;
 | 
			
		||||
bool one_time_init_success = false;
 | 
			
		||||
 | 
			
		||||
SSL_CTX* ssl_ctx;
 | 
			
		||||
IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment
 | 
			
		||||
BIO_METHOD* bio_meth;
 | 
			
		||||
 | 
			
		||||
Result CheckOpenSSLErrors();
 | 
			
		||||
void OneTimeInit();
 | 
			
		||||
void OneTimeInitLogFile();
 | 
			
		||||
bool OneTimeInitBIO();
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend {
 | 
			
		||||
public:
 | 
			
		||||
    Result Init() {
 | 
			
		||||
        std::call_once(one_time_init_flag, OneTimeInit);
 | 
			
		||||
 | 
			
		||||
        if (!one_time_init_success) {
 | 
			
		||||
            LOG_ERROR(Service_SSL,
 | 
			
		||||
                      "Can't create SSL connection because OpenSSL one-time initialization failed");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        ssl = SSL_new(ssl_ctx);
 | 
			
		||||
        if (!ssl) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSL_new failed");
 | 
			
		||||
            return CheckOpenSSLErrors();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        SSL_set_connect_state(ssl);
 | 
			
		||||
 | 
			
		||||
        bio = BIO_new(bio_meth);
 | 
			
		||||
        if (!bio) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "BIO_new failed");
 | 
			
		||||
            return CheckOpenSSLErrors();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        BIO_set_data(bio, this);
 | 
			
		||||
        BIO_set_init(bio, 1);
 | 
			
		||||
        SSL_set_bio(ssl, bio, bio);
 | 
			
		||||
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
 | 
			
		||||
        socket = std::move(socket_in);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result SetHostName(const std::string& hostname) override {
 | 
			
		||||
        if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
 | 
			
		||||
            return CheckOpenSSLErrors();
 | 
			
		||||
        }
 | 
			
		||||
        if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
 | 
			
		||||
            return CheckOpenSSLErrors();
 | 
			
		||||
        }
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result DoHandshake() override {
 | 
			
		||||
        SSL_set_verify_result(ssl, X509_V_OK);
 | 
			
		||||
        const int ret = SSL_do_handshake(ssl);
 | 
			
		||||
        const long verify_result = SSL_get_verify_result(ssl);
 | 
			
		||||
        if (verify_result != X509_V_OK) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
 | 
			
		||||
                      X509_verify_cert_error_string(verify_result));
 | 
			
		||||
            return CheckOpenSSLErrors();
 | 
			
		||||
        }
 | 
			
		||||
        if (ret <= 0) {
 | 
			
		||||
            const int ssl_err = SSL_get_error(ssl, ret);
 | 
			
		||||
            if (ssl_err == SSL_ERROR_ZERO_RETURN ||
 | 
			
		||||
                (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) {
 | 
			
		||||
                LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
 | 
			
		||||
                return ResultInternalError;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        return HandleReturn("SSL_do_handshake", 0, ret).Code();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Read(std::span<u8> data) override {
 | 
			
		||||
        size_t actual;
 | 
			
		||||
        const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual);
 | 
			
		||||
        return HandleReturn("SSL_read_ex", actual, ret);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Write(std::span<const u8> data) override {
 | 
			
		||||
        size_t actual;
 | 
			
		||||
        const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual);
 | 
			
		||||
        return HandleReturn("SSL_write_ex", actual, ret);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
 | 
			
		||||
        const int ssl_err = SSL_get_error(ssl, ret);
 | 
			
		||||
        CheckOpenSSLErrors();
 | 
			
		||||
        switch (ssl_err) {
 | 
			
		||||
        case SSL_ERROR_NONE:
 | 
			
		||||
            return actual;
 | 
			
		||||
        case SSL_ERROR_ZERO_RETURN:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
 | 
			
		||||
            // DoHandshake special-cases this, but for Read and Write:
 | 
			
		||||
            return size_t(0);
 | 
			
		||||
        case SSL_ERROR_WANT_READ:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
 | 
			
		||||
            return ResultWouldBlock;
 | 
			
		||||
        case SSL_ERROR_WANT_WRITE:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
 | 
			
		||||
            return ResultWouldBlock;
 | 
			
		||||
        default:
 | 
			
		||||
            if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) {
 | 
			
		||||
                LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
 | 
			
		||||
                return size_t(0);
 | 
			
		||||
            }
 | 
			
		||||
            LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
 | 
			
		||||
        STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
 | 
			
		||||
        if (!chain) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        std::vector<std::vector<u8>> ret;
 | 
			
		||||
        int count = sk_X509_num(chain);
 | 
			
		||||
        ASSERT(count >= 0);
 | 
			
		||||
        for (int i = 0; i < count; i++) {
 | 
			
		||||
            X509* x509 = sk_X509_value(chain, i);
 | 
			
		||||
            ASSERT_OR_EXECUTE(x509 != nullptr, { continue; });
 | 
			
		||||
            unsigned char* buf = nullptr;
 | 
			
		||||
            int len = i2d_X509(x509, &buf);
 | 
			
		||||
            ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
 | 
			
		||||
            ret.emplace_back(buf, buf + len);
 | 
			
		||||
            OPENSSL_free(buf);
 | 
			
		||||
        }
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ~SSLConnectionBackendOpenSSL() {
 | 
			
		||||
        // these are null-tolerant:
 | 
			
		||||
        SSL_free(ssl);
 | 
			
		||||
        BIO_free(bio);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static void KeyLogCallback(const SSL* ssl, const char* line) {
 | 
			
		||||
        std::string str(line);
 | 
			
		||||
        str.push_back('\n');
 | 
			
		||||
        // Do this in a single WriteString for atomicity if multiple instances
 | 
			
		||||
        // are running on different threads (though that can't currently
 | 
			
		||||
        // happen).
 | 
			
		||||
        if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) {
 | 
			
		||||
            LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE");
 | 
			
		||||
        }
 | 
			
		||||
        LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
 | 
			
		||||
        auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
 | 
			
		||||
        ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
            self->socket, { return 0; }, "OpenSSL asked to send but we have no socket");
 | 
			
		||||
        BIO_clear_retry_flags(bio);
 | 
			
		||||
        auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0);
 | 
			
		||||
        switch (err) {
 | 
			
		||||
        case Network::Errno::SUCCESS:
 | 
			
		||||
            *actual_p = actual;
 | 
			
		||||
            return 1;
 | 
			
		||||
        case Network::Errno::AGAIN:
 | 
			
		||||
            BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY);
 | 
			
		||||
            return 0;
 | 
			
		||||
        default:
 | 
			
		||||
            LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
 | 
			
		||||
            return -1;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
 | 
			
		||||
        auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
 | 
			
		||||
        ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
            self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket");
 | 
			
		||||
        BIO_clear_retry_flags(bio);
 | 
			
		||||
        auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len});
 | 
			
		||||
        switch (err) {
 | 
			
		||||
        case Network::Errno::SUCCESS:
 | 
			
		||||
            *actual_p = actual;
 | 
			
		||||
            if (actual == 0) {
 | 
			
		||||
                self->got_read_eof = true;
 | 
			
		||||
            }
 | 
			
		||||
            return actual ? 1 : 0;
 | 
			
		||||
        case Network::Errno::AGAIN:
 | 
			
		||||
            BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY);
 | 
			
		||||
            return 0;
 | 
			
		||||
        default:
 | 
			
		||||
            LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
 | 
			
		||||
            return -1;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) {
 | 
			
		||||
        switch (cmd) {
 | 
			
		||||
        case BIO_CTRL_FLUSH:
 | 
			
		||||
            // Nothing to flush.
 | 
			
		||||
            return 1;
 | 
			
		||||
        case BIO_CTRL_PUSH:
 | 
			
		||||
        case BIO_CTRL_POP:
 | 
			
		||||
#ifdef BIO_CTRL_GET_KTLS_SEND
 | 
			
		||||
        case BIO_CTRL_GET_KTLS_SEND:
 | 
			
		||||
        case BIO_CTRL_GET_KTLS_RECV:
 | 
			
		||||
#endif
 | 
			
		||||
            // We don't support these operations, but don't bother logging them
 | 
			
		||||
            // as they're nothing unusual.
 | 
			
		||||
            return 0;
 | 
			
		||||
        default:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg);
 | 
			
		||||
            return 0;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    SSL* ssl = nullptr;
 | 
			
		||||
    BIO* bio = nullptr;
 | 
			
		||||
    bool got_read_eof = false;
 | 
			
		||||
 | 
			
		||||
    std::shared_ptr<Network::SocketBase> socket;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
 | 
			
		||||
    auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
 | 
			
		||||
    const Result res = conn->Init();
 | 
			
		||||
    if (res.IsFailure()) {
 | 
			
		||||
        return res;
 | 
			
		||||
    }
 | 
			
		||||
    return conn;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
Result CheckOpenSSLErrors() {
 | 
			
		||||
    unsigned long rc;
 | 
			
		||||
    const char* file;
 | 
			
		||||
    int line;
 | 
			
		||||
    const char* func;
 | 
			
		||||
    const char* data;
 | 
			
		||||
    int flags;
 | 
			
		||||
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
 | 
			
		||||
    while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags)))
 | 
			
		||||
#else
 | 
			
		||||
    // Can't get function names from OpenSSL on this version, so use mine:
 | 
			
		||||
    func = __func__;
 | 
			
		||||
    while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags)))
 | 
			
		||||
#endif
 | 
			
		||||
    {
 | 
			
		||||
        std::string msg;
 | 
			
		||||
        msg.resize(1024, '\0');
 | 
			
		||||
        ERR_error_string_n(rc, msg.data(), msg.size());
 | 
			
		||||
        msg.resize(strlen(msg.data()), '\0');
 | 
			
		||||
        if (flags & ERR_TXT_STRING) {
 | 
			
		||||
            msg.append(" | ");
 | 
			
		||||
            msg.append(data);
 | 
			
		||||
        }
 | 
			
		||||
        Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error,
 | 
			
		||||
                                   Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}",
 | 
			
		||||
                                   msg);
 | 
			
		||||
    }
 | 
			
		||||
    return ResultInternalError;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void OneTimeInit() {
 | 
			
		||||
    ssl_ctx = SSL_CTX_new(TLS_client_method());
 | 
			
		||||
    if (!ssl_ctx) {
 | 
			
		||||
        LOG_ERROR(Service_SSL, "SSL_CTX_new failed");
 | 
			
		||||
        CheckOpenSSLErrors();
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr);
 | 
			
		||||
 | 
			
		||||
    if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) {
 | 
			
		||||
        LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed");
 | 
			
		||||
        CheckOpenSSLErrors();
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    OneTimeInitLogFile();
 | 
			
		||||
 | 
			
		||||
    if (!OneTimeInitBIO()) {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    one_time_init_success = true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void OneTimeInitLogFile() {
 | 
			
		||||
    const char* logfile = getenv("SSLKEYLOGFILE");
 | 
			
		||||
    if (logfile) {
 | 
			
		||||
        key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile,
 | 
			
		||||
                          FileShareFlag::ShareWriteOnly);
 | 
			
		||||
        if (key_log_file.IsOpen()) {
 | 
			
		||||
            SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback);
 | 
			
		||||
        } else {
 | 
			
		||||
            LOG_CRITICAL(Service_SSL,
 | 
			
		||||
                         "SSLKEYLOGFILE was set but file could not be opened; not logging keys!");
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool OneTimeInitBIO() {
 | 
			
		||||
    bio_meth =
 | 
			
		||||
        BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL");
 | 
			
		||||
    if (!bio_meth ||
 | 
			
		||||
        !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) ||
 | 
			
		||||
        !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) ||
 | 
			
		||||
        !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) {
 | 
			
		||||
        LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD");
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
} // namespace Service::SSL
 | 
			
		||||
							
								
								
									
										543
									
								
								src/core/hle/service/ssl/ssl_backend_schannel.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										543
									
								
								src/core/hle/service/ssl/ssl_backend_schannel.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,543 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#include "core/hle/service/ssl/ssl_backend.h"
 | 
			
		||||
#include "core/internal_network/network.h"
 | 
			
		||||
#include "core/internal_network/sockets.h"
 | 
			
		||||
 | 
			
		||||
#include "common/error.h"
 | 
			
		||||
#include "common/fs/file.h"
 | 
			
		||||
#include "common/hex_util.h"
 | 
			
		||||
#include "common/string_util.h"
 | 
			
		||||
 | 
			
		||||
#include <mutex>
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
// These includes are inside the namespace to avoid a conflict on MinGW where
 | 
			
		||||
// the headers define an enum containing Network and Service as enumerators
 | 
			
		||||
// (which clash with the correspondingly named namespaces).
 | 
			
		||||
#define SECURITY_WIN32
 | 
			
		||||
#include <schnlsp.h>
 | 
			
		||||
#include <security.h>
 | 
			
		||||
 | 
			
		||||
std::once_flag one_time_init_flag;
 | 
			
		||||
bool one_time_init_success = false;
 | 
			
		||||
 | 
			
		||||
SCHANNEL_CRED schannel_cred{};
 | 
			
		||||
CredHandle cred_handle;
 | 
			
		||||
 | 
			
		||||
static void OneTimeInit() {
 | 
			
		||||
    schannel_cred.dwVersion = SCHANNEL_CRED_VERSION;
 | 
			
		||||
    schannel_cred.dwFlags =
 | 
			
		||||
        SCH_USE_STRONG_CRYPTO |         // don't allow insecure protocols
 | 
			
		||||
        SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
 | 
			
		||||
        SCH_CRED_NO_DEFAULT_CREDS;      // don't automatically present a client certificate
 | 
			
		||||
    // ^ I'm assuming that nobody would want to connect Yuzu to a
 | 
			
		||||
    // service that requires some OS-provided corporate client
 | 
			
		||||
    // certificate, and presenting one to some arbitrary server
 | 
			
		||||
    // might be a privacy concern?  Who knows, though.
 | 
			
		||||
 | 
			
		||||
    const SECURITY_STATUS ret =
 | 
			
		||||
        AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
 | 
			
		||||
                                 nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
 | 
			
		||||
    if (ret != SEC_E_OK) {
 | 
			
		||||
        // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
 | 
			
		||||
        LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}",
 | 
			
		||||
                  Common::NativeErrorToString(ret));
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (getenv("SSLKEYLOGFILE")) {
 | 
			
		||||
        LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting "
 | 
			
		||||
                                  "keys; not logging keys!");
 | 
			
		||||
        // Not fatal.
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    one_time_init_success = true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
namespace Service::SSL {
 | 
			
		||||
 | 
			
		||||
class SSLConnectionBackendSchannel final : public SSLConnectionBackend {
 | 
			
		||||
public:
 | 
			
		||||
    Result Init() {
 | 
			
		||||
        std::call_once(one_time_init_flag, OneTimeInit);
 | 
			
		||||
 | 
			
		||||
        if (!one_time_init_success) {
 | 
			
		||||
            LOG_ERROR(
 | 
			
		||||
                Service_SSL,
 | 
			
		||||
                "Can't create SSL connection because Schannel one-time initialization failed");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
 | 
			
		||||
        socket = std::move(socket_in);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result SetHostName(const std::string& hostname_in) override {
 | 
			
		||||
        hostname = hostname_in;
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result DoHandshake() override {
 | 
			
		||||
        while (1) {
 | 
			
		||||
            Result r;
 | 
			
		||||
            switch (handshake_state) {
 | 
			
		||||
            case HandshakeState::Initial:
 | 
			
		||||
                if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
 | 
			
		||||
                    (r = CallInitializeSecurityContext()) != ResultSuccess) {
 | 
			
		||||
                    return r;
 | 
			
		||||
                }
 | 
			
		||||
                // CallInitializeSecurityContext updated `handshake_state`.
 | 
			
		||||
                continue;
 | 
			
		||||
            case HandshakeState::ContinueNeeded:
 | 
			
		||||
            case HandshakeState::IncompleteMessage:
 | 
			
		||||
                if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
 | 
			
		||||
                    (r = FillCiphertextReadBuf()) != ResultSuccess) {
 | 
			
		||||
                    return r;
 | 
			
		||||
                }
 | 
			
		||||
                if (ciphertext_read_buf.empty()) {
 | 
			
		||||
                    LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
 | 
			
		||||
                    return ResultInternalError;
 | 
			
		||||
                }
 | 
			
		||||
                if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
 | 
			
		||||
                    return r;
 | 
			
		||||
                }
 | 
			
		||||
                // CallInitializeSecurityContext updated `handshake_state`.
 | 
			
		||||
                continue;
 | 
			
		||||
            case HandshakeState::DoneAfterFlush:
 | 
			
		||||
                if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
 | 
			
		||||
                    return r;
 | 
			
		||||
                }
 | 
			
		||||
                handshake_state = HandshakeState::Connected;
 | 
			
		||||
                return ResultSuccess;
 | 
			
		||||
            case HandshakeState::Connected:
 | 
			
		||||
                LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
 | 
			
		||||
                return ResultInternalError;
 | 
			
		||||
            case HandshakeState::Error:
 | 
			
		||||
                return ResultInternalError;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result FillCiphertextReadBuf() {
 | 
			
		||||
        const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096;
 | 
			
		||||
        read_buf_fill_size = 0;
 | 
			
		||||
        // This unnecessarily zeroes the buffer; oh well.
 | 
			
		||||
        const size_t offset = ciphertext_read_buf.size();
 | 
			
		||||
        ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
 | 
			
		||||
        ciphertext_read_buf.resize(offset + fill_size, 0);
 | 
			
		||||
        const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size);
 | 
			
		||||
        const auto [actual, err] = socket->Recv(0, read_span);
 | 
			
		||||
        switch (err) {
 | 
			
		||||
        case Network::Errno::SUCCESS:
 | 
			
		||||
            ASSERT(static_cast<size_t>(actual) <= fill_size);
 | 
			
		||||
            ciphertext_read_buf.resize(offset + actual);
 | 
			
		||||
            return ResultSuccess;
 | 
			
		||||
        case Network::Errno::AGAIN:
 | 
			
		||||
            ciphertext_read_buf.resize(offset);
 | 
			
		||||
            return ResultWouldBlock;
 | 
			
		||||
        default:
 | 
			
		||||
            ciphertext_read_buf.resize(offset);
 | 
			
		||||
            LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Returns success if the write buffer has been completely emptied.
 | 
			
		||||
    Result FlushCiphertextWriteBuf() {
 | 
			
		||||
        while (!ciphertext_write_buf.empty()) {
 | 
			
		||||
            const auto [actual, err] = socket->Send(ciphertext_write_buf, 0);
 | 
			
		||||
            switch (err) {
 | 
			
		||||
            case Network::Errno::SUCCESS:
 | 
			
		||||
                ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size());
 | 
			
		||||
                ciphertext_write_buf.erase(ciphertext_write_buf.begin(),
 | 
			
		||||
                                           ciphertext_write_buf.begin() + actual);
 | 
			
		||||
                break;
 | 
			
		||||
            case Network::Errno::AGAIN:
 | 
			
		||||
                return ResultWouldBlock;
 | 
			
		||||
            default:
 | 
			
		||||
                LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
 | 
			
		||||
                return ResultInternalError;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result CallInitializeSecurityContext() {
 | 
			
		||||
        const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY |
 | 
			
		||||
                                  ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT |
 | 
			
		||||
                                  ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
 | 
			
		||||
                                  ISC_REQ_USE_SUPPLIED_CREDS;
 | 
			
		||||
        unsigned long attr;
 | 
			
		||||
        // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
 | 
			
		||||
        std::array<SecBuffer, 2> input_buffers{{
 | 
			
		||||
            // only used if `initial_call_done`
 | 
			
		||||
            {
 | 
			
		||||
                // [0]
 | 
			
		||||
                .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
 | 
			
		||||
                .BufferType = SECBUFFER_TOKEN,
 | 
			
		||||
                .pvBuffer = ciphertext_read_buf.data(),
 | 
			
		||||
            },
 | 
			
		||||
            {
 | 
			
		||||
                // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
 | 
			
		||||
                //     returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
 | 
			
		||||
                //     whole buffer wasn't used)
 | 
			
		||||
                .cbBuffer = 0,
 | 
			
		||||
                .BufferType = SECBUFFER_EMPTY,
 | 
			
		||||
                .pvBuffer = nullptr,
 | 
			
		||||
            },
 | 
			
		||||
        }};
 | 
			
		||||
        std::array<SecBuffer, 2> output_buffers{{
 | 
			
		||||
            {
 | 
			
		||||
                .cbBuffer = 0,
 | 
			
		||||
                .BufferType = SECBUFFER_TOKEN,
 | 
			
		||||
                .pvBuffer = nullptr,
 | 
			
		||||
            }, // [0]
 | 
			
		||||
            {
 | 
			
		||||
                .cbBuffer = 0,
 | 
			
		||||
                .BufferType = SECBUFFER_ALERT,
 | 
			
		||||
                .pvBuffer = nullptr,
 | 
			
		||||
            }, // [1]
 | 
			
		||||
        }};
 | 
			
		||||
        SecBufferDesc input_desc{
 | 
			
		||||
            .ulVersion = SECBUFFER_VERSION,
 | 
			
		||||
            .cBuffers = static_cast<unsigned long>(input_buffers.size()),
 | 
			
		||||
            .pBuffers = input_buffers.data(),
 | 
			
		||||
        };
 | 
			
		||||
        SecBufferDesc output_desc{
 | 
			
		||||
            .ulVersion = SECBUFFER_VERSION,
 | 
			
		||||
            .cBuffers = static_cast<unsigned long>(output_buffers.size()),
 | 
			
		||||
            .pBuffers = output_buffers.data(),
 | 
			
		||||
        };
 | 
			
		||||
        ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
            input_buffers[0].cbBuffer == ciphertext_read_buf.size(),
 | 
			
		||||
            { return ResultInternalError; }, "read buffer too large");
 | 
			
		||||
 | 
			
		||||
        bool initial_call_done = handshake_state != HandshakeState::Initial;
 | 
			
		||||
        if (initial_call_done) {
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
 | 
			
		||||
                      ciphertext_read_buf.size());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        const SECURITY_STATUS ret =
 | 
			
		||||
            InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr,
 | 
			
		||||
                                       // Caller ensured we have set a hostname:
 | 
			
		||||
                                       const_cast<char*>(hostname.value().c_str()), req,
 | 
			
		||||
                                       0, // Reserved1
 | 
			
		||||
                                       0, // TargetDataRep not used with Schannel
 | 
			
		||||
                                       initial_call_done ? &input_desc : nullptr,
 | 
			
		||||
                                       0, // Reserved2
 | 
			
		||||
                                       initial_call_done ? nullptr : &ctxt, &output_desc, &attr,
 | 
			
		||||
                                       nullptr); // ptsExpiry
 | 
			
		||||
 | 
			
		||||
        if (output_buffers[0].pvBuffer) {
 | 
			
		||||
            const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
 | 
			
		||||
                                 output_buffers[0].cbBuffer);
 | 
			
		||||
            ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end());
 | 
			
		||||
            FreeContextBuffer(output_buffers[0].pvBuffer);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (output_buffers[1].pvBuffer) {
 | 
			
		||||
            const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
 | 
			
		||||
                                 output_buffers[1].cbBuffer);
 | 
			
		||||
            // The documentation doesn't explain what format this data is in.
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
 | 
			
		||||
                      Common::HexToString(span));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        switch (ret) {
 | 
			
		||||
        case SEC_I_CONTINUE_NEEDED:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
 | 
			
		||||
            if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
 | 
			
		||||
                LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
 | 
			
		||||
                ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size());
 | 
			
		||||
                ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
 | 
			
		||||
                                          ciphertext_read_buf.end() - input_buffers[1].cbBuffer);
 | 
			
		||||
            } else {
 | 
			
		||||
                ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
 | 
			
		||||
                ciphertext_read_buf.clear();
 | 
			
		||||
            }
 | 
			
		||||
            handshake_state = HandshakeState::ContinueNeeded;
 | 
			
		||||
            return ResultSuccess;
 | 
			
		||||
        case SEC_E_INCOMPLETE_MESSAGE:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
 | 
			
		||||
            ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
 | 
			
		||||
            read_buf_fill_size = input_buffers[1].cbBuffer;
 | 
			
		||||
            handshake_state = HandshakeState::IncompleteMessage;
 | 
			
		||||
            return ResultSuccess;
 | 
			
		||||
        case SEC_E_OK:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
 | 
			
		||||
            ciphertext_read_buf.clear();
 | 
			
		||||
            handshake_state = HandshakeState::DoneAfterFlush;
 | 
			
		||||
            return GrabStreamSizes();
 | 
			
		||||
        default:
 | 
			
		||||
            LOG_ERROR(Service_SSL,
 | 
			
		||||
                      "InitializeSecurityContext failed (probably certificate/protocol issue): {}",
 | 
			
		||||
                      Common::NativeErrorToString(ret));
 | 
			
		||||
            handshake_state = HandshakeState::Error;
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result GrabStreamSizes() {
 | 
			
		||||
        const SECURITY_STATUS ret =
 | 
			
		||||
            QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
 | 
			
		||||
        if (ret != SEC_E_OK) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
 | 
			
		||||
                      Common::NativeErrorToString(ret));
 | 
			
		||||
            handshake_state = HandshakeState::Error;
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Read(std::span<u8> data) override {
 | 
			
		||||
        if (handshake_state != HandshakeState::Connected) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        if (data.size() == 0 || got_read_eof) {
 | 
			
		||||
            return size_t(0);
 | 
			
		||||
        }
 | 
			
		||||
        while (1) {
 | 
			
		||||
            if (!cleartext_read_buf.empty()) {
 | 
			
		||||
                const size_t read_size = std::min(cleartext_read_buf.size(), data.size());
 | 
			
		||||
                std::memcpy(data.data(), cleartext_read_buf.data(), read_size);
 | 
			
		||||
                cleartext_read_buf.erase(cleartext_read_buf.begin(),
 | 
			
		||||
                                         cleartext_read_buf.begin() + read_size);
 | 
			
		||||
                return read_size;
 | 
			
		||||
            }
 | 
			
		||||
            if (!ciphertext_read_buf.empty()) {
 | 
			
		||||
                SecBuffer empty{
 | 
			
		||||
                    .cbBuffer = 0,
 | 
			
		||||
                    .BufferType = SECBUFFER_EMPTY,
 | 
			
		||||
                    .pvBuffer = nullptr,
 | 
			
		||||
                };
 | 
			
		||||
                std::array<SecBuffer, 5> buffers{{
 | 
			
		||||
                    {
 | 
			
		||||
                        .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
 | 
			
		||||
                        .BufferType = SECBUFFER_DATA,
 | 
			
		||||
                        .pvBuffer = ciphertext_read_buf.data(),
 | 
			
		||||
                    },
 | 
			
		||||
                    empty,
 | 
			
		||||
                    empty,
 | 
			
		||||
                    empty,
 | 
			
		||||
                }};
 | 
			
		||||
                ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
                    buffers[0].cbBuffer == ciphertext_read_buf.size(),
 | 
			
		||||
                    { return ResultInternalError; }, "read buffer too large");
 | 
			
		||||
                SecBufferDesc desc{
 | 
			
		||||
                    .ulVersion = SECBUFFER_VERSION,
 | 
			
		||||
                    .cBuffers = static_cast<unsigned long>(buffers.size()),
 | 
			
		||||
                    .pBuffers = buffers.data(),
 | 
			
		||||
                };
 | 
			
		||||
                SECURITY_STATUS ret =
 | 
			
		||||
                    DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
 | 
			
		||||
                switch (ret) {
 | 
			
		||||
                case SEC_E_OK:
 | 
			
		||||
                    ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
 | 
			
		||||
                                      { return ResultInternalError; });
 | 
			
		||||
                    ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA,
 | 
			
		||||
                                      { return ResultInternalError; });
 | 
			
		||||
                    ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
 | 
			
		||||
                                      { return ResultInternalError; });
 | 
			
		||||
                    cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer),
 | 
			
		||||
                                              static_cast<u8*>(buffers[1].pvBuffer) +
 | 
			
		||||
                                                  buffers[1].cbBuffer);
 | 
			
		||||
                    if (buffers[3].BufferType == SECBUFFER_EXTRA) {
 | 
			
		||||
                        ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size());
 | 
			
		||||
                        ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
 | 
			
		||||
                                                  ciphertext_read_buf.end() - buffers[3].cbBuffer);
 | 
			
		||||
                    } else {
 | 
			
		||||
                        ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
 | 
			
		||||
                        ciphertext_read_buf.clear();
 | 
			
		||||
                    }
 | 
			
		||||
                    continue;
 | 
			
		||||
                case SEC_E_INCOMPLETE_MESSAGE:
 | 
			
		||||
                    break;
 | 
			
		||||
                case SEC_I_CONTEXT_EXPIRED:
 | 
			
		||||
                    // Server hung up by sending close_notify.
 | 
			
		||||
                    got_read_eof = true;
 | 
			
		||||
                    return size_t(0);
 | 
			
		||||
                default:
 | 
			
		||||
                    LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
 | 
			
		||||
                              Common::NativeErrorToString(ret));
 | 
			
		||||
                    return ResultInternalError;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            const Result r = FillCiphertextReadBuf();
 | 
			
		||||
            if (r != ResultSuccess) {
 | 
			
		||||
                return r;
 | 
			
		||||
            }
 | 
			
		||||
            if (ciphertext_read_buf.empty()) {
 | 
			
		||||
                got_read_eof = true;
 | 
			
		||||
                return size_t(0);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Write(std::span<const u8> data) override {
 | 
			
		||||
        if (handshake_state != HandshakeState::Connected) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        if (data.size() == 0) {
 | 
			
		||||
            return size_t(0);
 | 
			
		||||
        }
 | 
			
		||||
        data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage));
 | 
			
		||||
        if (!cleartext_write_buf.empty()) {
 | 
			
		||||
            // Already in the middle of a write.  It wouldn't make sense to not
 | 
			
		||||
            // finish sending the entire buffer since TLS has
 | 
			
		||||
            // header/MAC/padding/etc.
 | 
			
		||||
            if (data.size() != cleartext_write_buf.size() ||
 | 
			
		||||
                std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) {
 | 
			
		||||
                LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
 | 
			
		||||
                return ResultInternalError;
 | 
			
		||||
            }
 | 
			
		||||
            return WriteAlreadyEncryptedData();
 | 
			
		||||
        } else {
 | 
			
		||||
            cleartext_write_buf.assign(data.begin(), data.end());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        std::vector<u8> header_buf(stream_sizes.cbHeader, 0);
 | 
			
		||||
        std::vector<u8> tmp_data_buf = cleartext_write_buf;
 | 
			
		||||
        std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0);
 | 
			
		||||
 | 
			
		||||
        std::array<SecBuffer, 3> buffers{{
 | 
			
		||||
            {
 | 
			
		||||
                .cbBuffer = stream_sizes.cbHeader,
 | 
			
		||||
                .BufferType = SECBUFFER_STREAM_HEADER,
 | 
			
		||||
                .pvBuffer = header_buf.data(),
 | 
			
		||||
            },
 | 
			
		||||
            {
 | 
			
		||||
                .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()),
 | 
			
		||||
                .BufferType = SECBUFFER_DATA,
 | 
			
		||||
                .pvBuffer = tmp_data_buf.data(),
 | 
			
		||||
            },
 | 
			
		||||
            {
 | 
			
		||||
                .cbBuffer = stream_sizes.cbTrailer,
 | 
			
		||||
                .BufferType = SECBUFFER_STREAM_TRAILER,
 | 
			
		||||
                .pvBuffer = trailer_buf.data(),
 | 
			
		||||
            },
 | 
			
		||||
        }};
 | 
			
		||||
        ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
            buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; },
 | 
			
		||||
            "temp buffer too large");
 | 
			
		||||
        SecBufferDesc desc{
 | 
			
		||||
            .ulVersion = SECBUFFER_VERSION,
 | 
			
		||||
            .cBuffers = static_cast<unsigned long>(buffers.size()),
 | 
			
		||||
            .pBuffers = buffers.data(),
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
 | 
			
		||||
        if (ret != SEC_E_OK) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(),
 | 
			
		||||
                                    header_buf.end());
 | 
			
		||||
        ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(),
 | 
			
		||||
                                    tmp_data_buf.end());
 | 
			
		||||
        ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(),
 | 
			
		||||
                                    trailer_buf.end());
 | 
			
		||||
        return WriteAlreadyEncryptedData();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> WriteAlreadyEncryptedData() {
 | 
			
		||||
        const Result r = FlushCiphertextWriteBuf();
 | 
			
		||||
        if (r != ResultSuccess) {
 | 
			
		||||
            return r;
 | 
			
		||||
        }
 | 
			
		||||
        // write buf is empty
 | 
			
		||||
        const size_t cleartext_bytes_written = cleartext_write_buf.size();
 | 
			
		||||
        cleartext_write_buf.clear();
 | 
			
		||||
        return cleartext_bytes_written;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
 | 
			
		||||
        PCCERT_CONTEXT returned_cert = nullptr;
 | 
			
		||||
        const SECURITY_STATUS ret =
 | 
			
		||||
            QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
 | 
			
		||||
        if (ret != SEC_E_OK) {
 | 
			
		||||
            LOG_ERROR(Service_SSL,
 | 
			
		||||
                      "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
 | 
			
		||||
                      Common::NativeErrorToString(ret));
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        PCCERT_CONTEXT some_cert = nullptr;
 | 
			
		||||
        std::vector<std::vector<u8>> certs;
 | 
			
		||||
        while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
 | 
			
		||||
            certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
 | 
			
		||||
                               static_cast<u8*>(some_cert->pbCertEncoded) +
 | 
			
		||||
                                   some_cert->cbCertEncoded);
 | 
			
		||||
        }
 | 
			
		||||
        std::reverse(certs.begin(),
 | 
			
		||||
                     certs.end()); // Windows returns certs in reverse order from what we want
 | 
			
		||||
        CertFreeCertificateContext(returned_cert);
 | 
			
		||||
        return certs;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ~SSLConnectionBackendSchannel() {
 | 
			
		||||
        if (handshake_state != HandshakeState::Initial) {
 | 
			
		||||
            DeleteSecurityContext(&ctxt);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    enum class HandshakeState {
 | 
			
		||||
        // Haven't called anything yet.
 | 
			
		||||
        Initial,
 | 
			
		||||
        // `SEC_I_CONTINUE_NEEDED` was returned by
 | 
			
		||||
        // `InitializeSecurityContext`; must finish sending data (if any) in
 | 
			
		||||
        // the write buffer, then read at least one byte before calling
 | 
			
		||||
        // `InitializeSecurityContext` again.
 | 
			
		||||
        ContinueNeeded,
 | 
			
		||||
        // `SEC_E_INCOMPLETE_MESSAGE` was returned by
 | 
			
		||||
        // `InitializeSecurityContext`; hopefully the write buffer is empty;
 | 
			
		||||
        // must read at least one byte before calling
 | 
			
		||||
        // `InitializeSecurityContext` again.
 | 
			
		||||
        IncompleteMessage,
 | 
			
		||||
        // `SEC_E_OK` was returned by `InitializeSecurityContext`; must
 | 
			
		||||
        // finish sending data in the write buffer before having `DoHandshake`
 | 
			
		||||
        // report success.
 | 
			
		||||
        DoneAfterFlush,
 | 
			
		||||
        // We finished the above and are now connected.  At this point, writing
 | 
			
		||||
        // and reading are separate 'state machines' represented by the
 | 
			
		||||
        // nonemptiness of the ciphertext and cleartext read and write buffers.
 | 
			
		||||
        Connected,
 | 
			
		||||
        // Another error was returned and we shouldn't allow initialization
 | 
			
		||||
        // to continue.
 | 
			
		||||
        Error,
 | 
			
		||||
    } handshake_state = HandshakeState::Initial;
 | 
			
		||||
 | 
			
		||||
    CtxtHandle ctxt;
 | 
			
		||||
    SecPkgContext_StreamSizes stream_sizes;
 | 
			
		||||
 | 
			
		||||
    std::shared_ptr<Network::SocketBase> socket;
 | 
			
		||||
    std::optional<std::string> hostname;
 | 
			
		||||
 | 
			
		||||
    std::vector<u8> ciphertext_read_buf;
 | 
			
		||||
    std::vector<u8> ciphertext_write_buf;
 | 
			
		||||
    std::vector<u8> cleartext_read_buf;
 | 
			
		||||
    std::vector<u8> cleartext_write_buf;
 | 
			
		||||
 | 
			
		||||
    bool got_read_eof = false;
 | 
			
		||||
    size_t read_buf_fill_size = 0;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
 | 
			
		||||
    auto conn = std::make_unique<SSLConnectionBackendSchannel>();
 | 
			
		||||
    const Result res = conn->Init();
 | 
			
		||||
    if (res.IsFailure()) {
 | 
			
		||||
        return res;
 | 
			
		||||
    }
 | 
			
		||||
    return conn;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Service::SSL
 | 
			
		||||
							
								
								
									
										219
									
								
								src/core/hle/service/ssl/ssl_backend_securetransport.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										219
									
								
								src/core/hle/service/ssl/ssl_backend_securetransport.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,219 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#include "core/hle/service/ssl/ssl_backend.h"
 | 
			
		||||
#include "core/internal_network/network.h"
 | 
			
		||||
#include "core/internal_network/sockets.h"
 | 
			
		||||
 | 
			
		||||
#include <mutex>
 | 
			
		||||
 | 
			
		||||
#include <Security/SecureTransport.h>
 | 
			
		||||
 | 
			
		||||
// SecureTransport has been deprecated in its entirety in favor of
 | 
			
		||||
// Network.framework, but that does not allow layering TLS on top of an
 | 
			
		||||
// arbitrary socket.
 | 
			
		||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
struct CFReleaser {
 | 
			
		||||
    T ptr;
 | 
			
		||||
 | 
			
		||||
    YUZU_NON_COPYABLE(CFReleaser);
 | 
			
		||||
    constexpr CFReleaser() : ptr(nullptr) {}
 | 
			
		||||
    constexpr CFReleaser(T ptr) : ptr(ptr) {}
 | 
			
		||||
    constexpr operator T() {
 | 
			
		||||
        return ptr;
 | 
			
		||||
    }
 | 
			
		||||
    ~CFReleaser() {
 | 
			
		||||
        if (ptr) {
 | 
			
		||||
            CFRelease(ptr);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
std::string CFStringToString(CFStringRef cfstr) {
 | 
			
		||||
    CFReleaser<CFDataRef> cfdata(
 | 
			
		||||
        CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0));
 | 
			
		||||
    ASSERT_OR_EXECUTE(cfdata, { return "???"; });
 | 
			
		||||
    return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)),
 | 
			
		||||
                       CFDataGetLength(cfdata));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::string OSStatusToString(OSStatus status) {
 | 
			
		||||
    CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr));
 | 
			
		||||
    if (!cfstr) {
 | 
			
		||||
        return "[unknown error]";
 | 
			
		||||
    }
 | 
			
		||||
    return CFStringToString(cfstr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
namespace Service::SSL {
 | 
			
		||||
 | 
			
		||||
class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend {
 | 
			
		||||
public:
 | 
			
		||||
    Result Init() {
 | 
			
		||||
        static std::once_flag once_flag;
 | 
			
		||||
        std::call_once(once_flag, []() {
 | 
			
		||||
            if (getenv("SSLKEYLOGFILE")) {
 | 
			
		||||
                LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not "
 | 
			
		||||
                                          "support exporting keys; not logging keys!");
 | 
			
		||||
                // Not fatal.
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType);
 | 
			
		||||
        if (!context) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSLCreateContext failed");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        OSStatus status;
 | 
			
		||||
        if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) ||
 | 
			
		||||
            (status = SSLSetConnection(context, this))) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}",
 | 
			
		||||
                      OSStatusToString(status));
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override {
 | 
			
		||||
        socket = std::move(in_socket);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result SetHostName(const std::string& hostname) override {
 | 
			
		||||
        OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size());
 | 
			
		||||
        if (status) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status));
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result DoHandshake() override {
 | 
			
		||||
        OSStatus status = SSLHandshake(context);
 | 
			
		||||
        return HandleReturn("SSLHandshake", 0, status).Code();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Read(std::span<u8> data) override {
 | 
			
		||||
        size_t actual;
 | 
			
		||||
        OSStatus status = SSLRead(context, data.data(), data.size(), &actual);
 | 
			
		||||
        ;
 | 
			
		||||
        return HandleReturn("SSLRead", actual, status);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Write(std::span<const u8> data) override {
 | 
			
		||||
        size_t actual;
 | 
			
		||||
        OSStatus status = SSLWrite(context, data.data(), data.size(), &actual);
 | 
			
		||||
        ;
 | 
			
		||||
        return HandleReturn("SSLWrite", actual, status);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) {
 | 
			
		||||
        switch (status) {
 | 
			
		||||
        case 0:
 | 
			
		||||
            return actual;
 | 
			
		||||
        case errSSLWouldBlock:
 | 
			
		||||
            return ResultWouldBlock;
 | 
			
		||||
        default: {
 | 
			
		||||
            std::string reason;
 | 
			
		||||
            if (got_read_eof) {
 | 
			
		||||
                reason = "server hung up";
 | 
			
		||||
            } else {
 | 
			
		||||
                reason = OSStatusToString(status);
 | 
			
		||||
            }
 | 
			
		||||
            LOG_ERROR(Service_SSL, "{} failed: {}", what, reason);
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
 | 
			
		||||
        CFReleaser<SecTrustRef> trust;
 | 
			
		||||
        OSStatus status = SSLCopyPeerTrust(context, &trust.ptr);
 | 
			
		||||
        if (status) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status));
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        std::vector<std::vector<u8>> ret;
 | 
			
		||||
        for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) {
 | 
			
		||||
            SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i);
 | 
			
		||||
            CFReleaser<CFDataRef> data(SecCertificateCopyData(cert));
 | 
			
		||||
            ASSERT_OR_EXECUTE(data, { return ResultInternalError; });
 | 
			
		||||
            const u8* ptr = CFDataGetBytePtr(data);
 | 
			
		||||
            ret.emplace_back(ptr, ptr + CFDataGetLength(data));
 | 
			
		||||
        }
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) {
 | 
			
		||||
        return ReadOrWriteCallback(connection, data, dataLength, true);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static OSStatus WriteCallback(SSLConnectionRef connection, const void* data,
 | 
			
		||||
                                  size_t* dataLength) {
 | 
			
		||||
        return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength,
 | 
			
		||||
                                        bool is_read) {
 | 
			
		||||
        auto self =
 | 
			
		||||
            static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection));
 | 
			
		||||
        ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
            self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket",
 | 
			
		||||
            is_read ? "read" : "write");
 | 
			
		||||
 | 
			
		||||
        // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are
 | 
			
		||||
        // expected to read/write the full requested dataLength or return an
 | 
			
		||||
        // error, so we have to add a loop ourselves.
 | 
			
		||||
        size_t requested_len = *dataLength;
 | 
			
		||||
        size_t offset = 0;
 | 
			
		||||
        while (offset < requested_len) {
 | 
			
		||||
            std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset);
 | 
			
		||||
            auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0);
 | 
			
		||||
            LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset,
 | 
			
		||||
                         actual, cur.size(), static_cast<s32>(err));
 | 
			
		||||
            switch (err) {
 | 
			
		||||
            case Network::Errno::SUCCESS:
 | 
			
		||||
                offset += actual;
 | 
			
		||||
                if (actual == 0) {
 | 
			
		||||
                    ASSERT(is_read);
 | 
			
		||||
                    self->got_read_eof = true;
 | 
			
		||||
                    return errSecEndOfData;
 | 
			
		||||
                }
 | 
			
		||||
                break;
 | 
			
		||||
            case Network::Errno::AGAIN:
 | 
			
		||||
                *dataLength = offset;
 | 
			
		||||
                return errSSLWouldBlock;
 | 
			
		||||
            default:
 | 
			
		||||
                LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}",
 | 
			
		||||
                          is_read ? "recv" : "send", err);
 | 
			
		||||
                return errSecIO;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        ASSERT(offset == requested_len);
 | 
			
		||||
        return 0;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    CFReleaser<SSLContextRef> context = nullptr;
 | 
			
		||||
    bool got_read_eof = false;
 | 
			
		||||
 | 
			
		||||
    std::shared_ptr<Network::SocketBase> socket;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
 | 
			
		||||
    auto conn = std::make_unique<SSLConnectionBackendSecureTransport>();
 | 
			
		||||
    const Result res = conn->Init();
 | 
			
		||||
    if (res.IsFailure()) {
 | 
			
		||||
        return res;
 | 
			
		||||
    }
 | 
			
		||||
    return conn;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Service::SSL
 | 
			
		||||
@@ -27,6 +27,7 @@
 | 
			
		||||
 | 
			
		||||
#include "common/assert.h"
 | 
			
		||||
#include "common/common_types.h"
 | 
			
		||||
#include "common/expected.h"
 | 
			
		||||
#include "common/logging/log.h"
 | 
			
		||||
#include "common/settings.h"
 | 
			
		||||
#include "core/internal_network/network.h"
 | 
			
		||||
@@ -97,6 +98,8 @@ bool EnableNonBlock(SOCKET fd, bool enable) {
 | 
			
		||||
 | 
			
		||||
Errno TranslateNativeError(int e) {
 | 
			
		||||
    switch (e) {
 | 
			
		||||
    case 0:
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    case WSAEBADF:
 | 
			
		||||
        return Errno::BADF;
 | 
			
		||||
    case WSAEINVAL:
 | 
			
		||||
@@ -121,6 +124,8 @@ Errno TranslateNativeError(int e) {
 | 
			
		||||
        return Errno::MSGSIZE;
 | 
			
		||||
    case WSAETIMEDOUT:
 | 
			
		||||
        return Errno::TIMEDOUT;
 | 
			
		||||
    case WSAEINPROGRESS:
 | 
			
		||||
        return Errno::INPROGRESS;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented errno={}", e);
 | 
			
		||||
        return Errno::OTHER;
 | 
			
		||||
@@ -195,6 +200,8 @@ bool EnableNonBlock(int fd, bool enable) {
 | 
			
		||||
 | 
			
		||||
Errno TranslateNativeError(int e) {
 | 
			
		||||
    switch (e) {
 | 
			
		||||
    case 0:
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    case EBADF:
 | 
			
		||||
        return Errno::BADF;
 | 
			
		||||
    case EINVAL:
 | 
			
		||||
@@ -219,8 +226,10 @@ Errno TranslateNativeError(int e) {
 | 
			
		||||
        return Errno::MSGSIZE;
 | 
			
		||||
    case ETIMEDOUT:
 | 
			
		||||
        return Errno::TIMEDOUT;
 | 
			
		||||
    case EINPROGRESS:
 | 
			
		||||
        return Errno::INPROGRESS;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented errno={}", e);
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented errno={} ({})", e, strerror(e));
 | 
			
		||||
        return Errno::OTHER;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -234,15 +243,84 @@ Errno GetAndLogLastError() {
 | 
			
		||||
    int e = errno;
 | 
			
		||||
#endif
 | 
			
		||||
    const Errno err = TranslateNativeError(e);
 | 
			
		||||
    if (err == Errno::AGAIN || err == Errno::TIMEDOUT) {
 | 
			
		||||
    if (err == Errno::AGAIN || err == Errno::TIMEDOUT || err == Errno::INPROGRESS) {
 | 
			
		||||
        // These happen during normal operation, so only log them at debug level.
 | 
			
		||||
        LOG_DEBUG(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
 | 
			
		||||
        return err;
 | 
			
		||||
    }
 | 
			
		||||
    LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
 | 
			
		||||
    return err;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int TranslateDomain(Domain domain) {
 | 
			
		||||
GetAddrInfoError TranslateGetAddrInfoErrorFromNative(int gai_err) {
 | 
			
		||||
    switch (gai_err) {
 | 
			
		||||
    case 0:
 | 
			
		||||
        return GetAddrInfoError::SUCCESS;
 | 
			
		||||
#ifdef EAI_ADDRFAMILY
 | 
			
		||||
    case EAI_ADDRFAMILY:
 | 
			
		||||
        return GetAddrInfoError::ADDRFAMILY;
 | 
			
		||||
#endif
 | 
			
		||||
    case EAI_AGAIN:
 | 
			
		||||
        return GetAddrInfoError::AGAIN;
 | 
			
		||||
    case EAI_BADFLAGS:
 | 
			
		||||
        return GetAddrInfoError::BADFLAGS;
 | 
			
		||||
    case EAI_FAIL:
 | 
			
		||||
        return GetAddrInfoError::FAIL;
 | 
			
		||||
    case EAI_FAMILY:
 | 
			
		||||
        return GetAddrInfoError::FAMILY;
 | 
			
		||||
    case EAI_MEMORY:
 | 
			
		||||
        return GetAddrInfoError::MEMORY;
 | 
			
		||||
    case EAI_NONAME:
 | 
			
		||||
        return GetAddrInfoError::NONAME;
 | 
			
		||||
    case EAI_SERVICE:
 | 
			
		||||
        return GetAddrInfoError::SERVICE;
 | 
			
		||||
    case EAI_SOCKTYPE:
 | 
			
		||||
        return GetAddrInfoError::SOCKTYPE;
 | 
			
		||||
        // These codes may not be defined on all systems:
 | 
			
		||||
#ifdef EAI_SYSTEM
 | 
			
		||||
    case EAI_SYSTEM:
 | 
			
		||||
        return GetAddrInfoError::SYSTEM;
 | 
			
		||||
#endif
 | 
			
		||||
#ifdef EAI_BADHINTS
 | 
			
		||||
    case EAI_BADHINTS:
 | 
			
		||||
        return GetAddrInfoError::BADHINTS;
 | 
			
		||||
#endif
 | 
			
		||||
#ifdef EAI_PROTOCOL
 | 
			
		||||
    case EAI_PROTOCOL:
 | 
			
		||||
        return GetAddrInfoError::PROTOCOL;
 | 
			
		||||
#endif
 | 
			
		||||
#ifdef EAI_OVERFLOW
 | 
			
		||||
    case EAI_OVERFLOW:
 | 
			
		||||
        return GetAddrInfoError::OVERFLOW_;
 | 
			
		||||
#endif
 | 
			
		||||
    default:
 | 
			
		||||
#ifdef EAI_NODATA
 | 
			
		||||
        // This can't be a case statement because it would create a duplicate
 | 
			
		||||
        // case on Windows where EAI_NODATA is an alias for EAI_NONAME.
 | 
			
		||||
        if (gai_err == EAI_NODATA) {
 | 
			
		||||
            return GetAddrInfoError::NODATA;
 | 
			
		||||
        }
 | 
			
		||||
#endif
 | 
			
		||||
        return GetAddrInfoError::OTHER;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Domain TranslateDomainFromNative(int domain) {
 | 
			
		||||
    switch (domain) {
 | 
			
		||||
    case 0:
 | 
			
		||||
        return Domain::Unspecified;
 | 
			
		||||
    case AF_INET:
 | 
			
		||||
        return Domain::INET;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unhandled domain={}", domain);
 | 
			
		||||
        return Domain::INET;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int TranslateDomainToNative(Domain domain) {
 | 
			
		||||
    switch (domain) {
 | 
			
		||||
    case Domain::Unspecified:
 | 
			
		||||
        return 0;
 | 
			
		||||
    case Domain::INET:
 | 
			
		||||
        return AF_INET;
 | 
			
		||||
    default:
 | 
			
		||||
@@ -251,20 +329,58 @@ int TranslateDomain(Domain domain) {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int TranslateType(Type type) {
 | 
			
		||||
Type TranslateTypeFromNative(int type) {
 | 
			
		||||
    switch (type) {
 | 
			
		||||
    case 0:
 | 
			
		||||
        return Type::Unspecified;
 | 
			
		||||
    case SOCK_STREAM:
 | 
			
		||||
        return Type::STREAM;
 | 
			
		||||
    case SOCK_DGRAM:
 | 
			
		||||
        return Type::DGRAM;
 | 
			
		||||
    case SOCK_RAW:
 | 
			
		||||
        return Type::RAW;
 | 
			
		||||
    case SOCK_SEQPACKET:
 | 
			
		||||
        return Type::SEQPACKET;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented type={}", type);
 | 
			
		||||
        return Type::STREAM;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int TranslateTypeToNative(Type type) {
 | 
			
		||||
    switch (type) {
 | 
			
		||||
    case Type::Unspecified:
 | 
			
		||||
        return 0;
 | 
			
		||||
    case Type::STREAM:
 | 
			
		||||
        return SOCK_STREAM;
 | 
			
		||||
    case Type::DGRAM:
 | 
			
		||||
        return SOCK_DGRAM;
 | 
			
		||||
    case Type::RAW:
 | 
			
		||||
        return SOCK_RAW;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented type={}", type);
 | 
			
		||||
        return 0;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int TranslateProtocol(Protocol protocol) {
 | 
			
		||||
Protocol TranslateProtocolFromNative(int protocol) {
 | 
			
		||||
    switch (protocol) {
 | 
			
		||||
    case 0:
 | 
			
		||||
        return Protocol::Unspecified;
 | 
			
		||||
    case IPPROTO_TCP:
 | 
			
		||||
        return Protocol::TCP;
 | 
			
		||||
    case IPPROTO_UDP:
 | 
			
		||||
        return Protocol::UDP;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
 | 
			
		||||
        return Protocol::Unspecified;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int TranslateProtocolToNative(Protocol protocol) {
 | 
			
		||||
    switch (protocol) {
 | 
			
		||||
    case Protocol::Unspecified:
 | 
			
		||||
        return 0;
 | 
			
		||||
    case Protocol::TCP:
 | 
			
		||||
        return IPPROTO_TCP;
 | 
			
		||||
    case Protocol::UDP:
 | 
			
		||||
@@ -275,21 +391,10 @@ int TranslateProtocol(Protocol protocol) {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
SockAddrIn TranslateToSockAddrIn(sockaddr input_) {
 | 
			
		||||
    sockaddr_in input;
 | 
			
		||||
    std::memcpy(&input, &input_, sizeof(input));
 | 
			
		||||
 | 
			
		||||
SockAddrIn TranslateToSockAddrIn(sockaddr_in input, size_t input_len) {
 | 
			
		||||
    SockAddrIn result;
 | 
			
		||||
 | 
			
		||||
    switch (input.sin_family) {
 | 
			
		||||
    case AF_INET:
 | 
			
		||||
        result.family = Domain::INET;
 | 
			
		||||
        break;
 | 
			
		||||
    default:
 | 
			
		||||
        UNIMPLEMENTED_MSG("Unhandled sockaddr family={}", input.sin_family);
 | 
			
		||||
        result.family = Domain::INET;
 | 
			
		||||
        break;
 | 
			
		||||
    }
 | 
			
		||||
    result.family = TranslateDomainFromNative(input.sin_family);
 | 
			
		||||
 | 
			
		||||
    result.portno = ntohs(input.sin_port);
 | 
			
		||||
 | 
			
		||||
@@ -301,22 +406,33 @@ SockAddrIn TranslateToSockAddrIn(sockaddr input_) {
 | 
			
		||||
short TranslatePollEvents(PollEvents events) {
 | 
			
		||||
    short result = 0;
 | 
			
		||||
 | 
			
		||||
    if (True(events & PollEvents::In)) {
 | 
			
		||||
        events &= ~PollEvents::In;
 | 
			
		||||
        result |= POLLIN;
 | 
			
		||||
    }
 | 
			
		||||
    if (True(events & PollEvents::Pri)) {
 | 
			
		||||
        events &= ~PollEvents::Pri;
 | 
			
		||||
    const auto translate = [&result, &events](PollEvents guest, short host) {
 | 
			
		||||
        if (True(events & guest)) {
 | 
			
		||||
            events &= ~guest;
 | 
			
		||||
            result |= host;
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    translate(PollEvents::In, POLLIN);
 | 
			
		||||
    translate(PollEvents::Pri, POLLPRI);
 | 
			
		||||
    translate(PollEvents::Out, POLLOUT);
 | 
			
		||||
    translate(PollEvents::Err, POLLERR);
 | 
			
		||||
    translate(PollEvents::Hup, POLLHUP);
 | 
			
		||||
    translate(PollEvents::Nval, POLLNVAL);
 | 
			
		||||
    translate(PollEvents::RdNorm, POLLRDNORM);
 | 
			
		||||
    translate(PollEvents::RdBand, POLLRDBAND);
 | 
			
		||||
    translate(PollEvents::WrBand, POLLWRBAND);
 | 
			
		||||
 | 
			
		||||
#ifdef _WIN32
 | 
			
		||||
        LOG_WARNING(Service, "Winsock doesn't support POLLPRI");
 | 
			
		||||
#else
 | 
			
		||||
        result |= POLLPRI;
 | 
			
		||||
    short allowed_events = POLLRDBAND | POLLRDNORM | POLLWRNORM;
 | 
			
		||||
    // Unlike poll on other OSes, WSAPoll will complain if any other flags are set on input.
 | 
			
		||||
    if (result & ~allowed_events) {
 | 
			
		||||
        LOG_DEBUG(Network,
 | 
			
		||||
                  "Removing WSAPoll input events 0x{:x} because Windows doesn't support them",
 | 
			
		||||
                  result & ~allowed_events);
 | 
			
		||||
    }
 | 
			
		||||
    result &= allowed_events;
 | 
			
		||||
#endif
 | 
			
		||||
    }
 | 
			
		||||
    if (True(events & PollEvents::Out)) {
 | 
			
		||||
        events &= ~PollEvents::Out;
 | 
			
		||||
        result |= POLLOUT;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events);
 | 
			
		||||
 | 
			
		||||
@@ -337,6 +453,10 @@ PollEvents TranslatePollRevents(short revents) {
 | 
			
		||||
    translate(POLLOUT, PollEvents::Out);
 | 
			
		||||
    translate(POLLERR, PollEvents::Err);
 | 
			
		||||
    translate(POLLHUP, PollEvents::Hup);
 | 
			
		||||
    translate(POLLNVAL, PollEvents::Nval);
 | 
			
		||||
    translate(POLLRDNORM, PollEvents::RdNorm);
 | 
			
		||||
    translate(POLLRDBAND, PollEvents::RdBand);
 | 
			
		||||
    translate(POLLWRBAND, PollEvents::WrBand);
 | 
			
		||||
 | 
			
		||||
    UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents);
 | 
			
		||||
 | 
			
		||||
@@ -360,12 +480,51 @@ std::optional<IPv4Address> GetHostIPv4Address() {
 | 
			
		||||
        return {};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    std::array<char, 16> ip_addr = {};
 | 
			
		||||
    ASSERT(inet_ntop(AF_INET, &network_interface->ip_address, ip_addr.data(), sizeof(ip_addr)) !=
 | 
			
		||||
           nullptr);
 | 
			
		||||
    return TranslateIPv4(network_interface->ip_address);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::string IPv4AddressToString(IPv4Address ip_addr) {
 | 
			
		||||
    std::array<char, INET_ADDRSTRLEN> buf = {};
 | 
			
		||||
    ASSERT(inet_ntop(AF_INET, &ip_addr, buf.data(), sizeof(buf)) == buf.data());
 | 
			
		||||
    return std::string(buf.data());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
u32 IPv4AddressToInteger(IPv4Address ip_addr) {
 | 
			
		||||
    return static_cast<u32>(ip_addr[0]) << 24 | static_cast<u32>(ip_addr[1]) << 16 |
 | 
			
		||||
           static_cast<u32>(ip_addr[2]) << 8 | static_cast<u32>(ip_addr[3]);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo(
 | 
			
		||||
    const std::string& host, const std::optional<std::string>& service) {
 | 
			
		||||
    addrinfo hints{};
 | 
			
		||||
    hints.ai_family = AF_INET; // Switch only supports IPv4.
 | 
			
		||||
    addrinfo* addrinfo;
 | 
			
		||||
    s32 gai_err = getaddrinfo(host.c_str(), service.has_value() ? service->c_str() : nullptr,
 | 
			
		||||
                              &hints, &addrinfo);
 | 
			
		||||
    if (gai_err != 0) {
 | 
			
		||||
        return Common::Unexpected(TranslateGetAddrInfoErrorFromNative(gai_err));
 | 
			
		||||
    }
 | 
			
		||||
    std::vector<AddrInfo> ret;
 | 
			
		||||
    for (auto* current = addrinfo; current; current = current->ai_next) {
 | 
			
		||||
        // We should only get AF_INET results due to the hints value.
 | 
			
		||||
        ASSERT_OR_EXECUTE(addrinfo->ai_family == AF_INET &&
 | 
			
		||||
                              addrinfo->ai_addrlen == sizeof(sockaddr_in),
 | 
			
		||||
                          continue;);
 | 
			
		||||
 | 
			
		||||
        AddrInfo& out = ret.emplace_back();
 | 
			
		||||
        out.family = TranslateDomainFromNative(current->ai_family);
 | 
			
		||||
        out.socket_type = TranslateTypeFromNative(current->ai_socktype);
 | 
			
		||||
        out.protocol = TranslateProtocolFromNative(current->ai_protocol);
 | 
			
		||||
        out.addr = TranslateToSockAddrIn(*reinterpret_cast<sockaddr_in*>(current->ai_addr),
 | 
			
		||||
                                         current->ai_addrlen);
 | 
			
		||||
        if (current->ai_canonname != nullptr) {
 | 
			
		||||
            out.canon_name = current->ai_canonname;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    freeaddrinfo(addrinfo);
 | 
			
		||||
    return ret;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) {
 | 
			
		||||
    const size_t num = pollfds.size();
 | 
			
		||||
 | 
			
		||||
@@ -411,9 +570,21 @@ Socket::Socket(Socket&& rhs) noexcept {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
 | 
			
		||||
std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_so, int option) {
 | 
			
		||||
    T value{};
 | 
			
		||||
    socklen_t len = sizeof(value);
 | 
			
		||||
    const int result = getsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len);
 | 
			
		||||
    if (result != SOCKET_ERROR) {
 | 
			
		||||
        ASSERT(len == sizeof(value));
 | 
			
		||||
        return {value, Errno::SUCCESS};
 | 
			
		||||
    }
 | 
			
		||||
    return {value, GetAndLogLastError()};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
Errno Socket::SetSockOpt(SOCKET fd_so, int option, T value) {
 | 
			
		||||
    const int result =
 | 
			
		||||
        setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value));
 | 
			
		||||
        setsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value));
 | 
			
		||||
    if (result != SOCKET_ERROR) {
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    }
 | 
			
		||||
@@ -421,7 +592,8 @@ Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
 | 
			
		||||
    fd = socket(TranslateDomain(domain), TranslateType(type), TranslateProtocol(protocol));
 | 
			
		||||
    fd = socket(TranslateDomainToNative(domain), TranslateTypeToNative(type),
 | 
			
		||||
                TranslateProtocolToNative(protocol));
 | 
			
		||||
    if (fd != INVALID_SOCKET) {
 | 
			
		||||
        return Errno::SUCCESS;
 | 
			
		||||
    }
 | 
			
		||||
@@ -430,19 +602,17 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() {
 | 
			
		||||
    sockaddr addr;
 | 
			
		||||
    sockaddr_in addr;
 | 
			
		||||
    socklen_t addrlen = sizeof(addr);
 | 
			
		||||
    const SOCKET new_socket = accept(fd, &addr, &addrlen);
 | 
			
		||||
    const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen);
 | 
			
		||||
 | 
			
		||||
    if (new_socket == INVALID_SOCKET) {
 | 
			
		||||
        return {AcceptResult{}, GetAndLogLastError()};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ASSERT(addrlen == sizeof(sockaddr_in));
 | 
			
		||||
 | 
			
		||||
    AcceptResult result{
 | 
			
		||||
        .socket = std::make_unique<Socket>(new_socket),
 | 
			
		||||
        .sockaddr_in = TranslateToSockAddrIn(addr),
 | 
			
		||||
        .sockaddr_in = TranslateToSockAddrIn(addr, addrlen),
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    return {std::move(result), Errno::SUCCESS};
 | 
			
		||||
@@ -458,25 +628,23 @@ Errno Socket::Connect(SockAddrIn addr_in) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<SockAddrIn, Errno> Socket::GetPeerName() {
 | 
			
		||||
    sockaddr addr;
 | 
			
		||||
    sockaddr_in addr;
 | 
			
		||||
    socklen_t addrlen = sizeof(addr);
 | 
			
		||||
    if (getpeername(fd, &addr, &addrlen) == SOCKET_ERROR) {
 | 
			
		||||
    if (getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
 | 
			
		||||
        return {SockAddrIn{}, GetAndLogLastError()};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ASSERT(addrlen == sizeof(sockaddr_in));
 | 
			
		||||
    return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
 | 
			
		||||
    return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<SockAddrIn, Errno> Socket::GetSockName() {
 | 
			
		||||
    sockaddr addr;
 | 
			
		||||
    sockaddr_in addr;
 | 
			
		||||
    socklen_t addrlen = sizeof(addr);
 | 
			
		||||
    if (getsockname(fd, &addr, &addrlen) == SOCKET_ERROR) {
 | 
			
		||||
    if (getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
 | 
			
		||||
        return {SockAddrIn{}, GetAndLogLastError()};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ASSERT(addrlen == sizeof(sockaddr_in));
 | 
			
		||||
    return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
 | 
			
		||||
    return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Errno Socket::Bind(SockAddrIn addr) {
 | 
			
		||||
@@ -519,7 +687,7 @@ Errno Socket::Shutdown(ShutdownHow how) {
 | 
			
		||||
    return GetAndLogLastError();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) {
 | 
			
		||||
std::pair<s32, Errno> Socket::Recv(int flags, std::span<u8> message) {
 | 
			
		||||
    ASSERT(flags == 0);
 | 
			
		||||
    ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
 | 
			
		||||
 | 
			
		||||
@@ -532,21 +700,20 @@ std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) {
 | 
			
		||||
    return {-1, GetAndLogLastError()};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<s32, Errno> Socket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) {
 | 
			
		||||
std::pair<s32, Errno> Socket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
 | 
			
		||||
    ASSERT(flags == 0);
 | 
			
		||||
    ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
 | 
			
		||||
 | 
			
		||||
    sockaddr addr_in{};
 | 
			
		||||
    sockaddr_in addr_in{};
 | 
			
		||||
    socklen_t addrlen = sizeof(addr_in);
 | 
			
		||||
    socklen_t* const p_addrlen = addr ? &addrlen : nullptr;
 | 
			
		||||
    sockaddr* const p_addr_in = addr ? &addr_in : nullptr;
 | 
			
		||||
    sockaddr* const p_addr_in = addr ? reinterpret_cast<sockaddr*>(&addr_in) : nullptr;
 | 
			
		||||
 | 
			
		||||
    const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()),
 | 
			
		||||
                                 static_cast<int>(message.size()), 0, p_addr_in, p_addrlen);
 | 
			
		||||
    if (result != SOCKET_ERROR) {
 | 
			
		||||
        if (addr) {
 | 
			
		||||
            ASSERT(addrlen == sizeof(addr_in));
 | 
			
		||||
            *addr = TranslateToSockAddrIn(addr_in);
 | 
			
		||||
            *addr = TranslateToSockAddrIn(addr_in, addrlen);
 | 
			
		||||
        }
 | 
			
		||||
        return {static_cast<s32>(result), Errno::SUCCESS};
 | 
			
		||||
    }
 | 
			
		||||
@@ -597,6 +764,11 @@ Errno Socket::Close() {
 | 
			
		||||
    return Errno::SUCCESS;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<Errno, Errno> Socket::GetPendingError() {
 | 
			
		||||
    auto [pending_err, getsockopt_err] = GetSockOpt<int>(fd, SO_ERROR);
 | 
			
		||||
    return {TranslateNativeError(pending_err), getsockopt_err};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Errno Socket::SetLinger(bool enable, u32 linger) {
 | 
			
		||||
    return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger));
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@
 | 
			
		||||
 | 
			
		||||
#include <array>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "common/common_funcs.h"
 | 
			
		||||
#include "common/common_types.h"
 | 
			
		||||
@@ -16,6 +17,11 @@
 | 
			
		||||
#include <netinet/in.h>
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
namespace Common {
 | 
			
		||||
template <typename T, typename E>
 | 
			
		||||
class Expected;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace Network {
 | 
			
		||||
 | 
			
		||||
class SocketBase;
 | 
			
		||||
@@ -36,6 +42,26 @@ enum class Errno {
 | 
			
		||||
    NETUNREACH,
 | 
			
		||||
    TIMEDOUT,
 | 
			
		||||
    MSGSIZE,
 | 
			
		||||
    INPROGRESS,
 | 
			
		||||
    OTHER,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
enum class GetAddrInfoError {
 | 
			
		||||
    SUCCESS,
 | 
			
		||||
    ADDRFAMILY,
 | 
			
		||||
    AGAIN,
 | 
			
		||||
    BADFLAGS,
 | 
			
		||||
    FAIL,
 | 
			
		||||
    FAMILY,
 | 
			
		||||
    MEMORY,
 | 
			
		||||
    NODATA,
 | 
			
		||||
    NONAME,
 | 
			
		||||
    SERVICE,
 | 
			
		||||
    SOCKTYPE,
 | 
			
		||||
    SYSTEM,
 | 
			
		||||
    BADHINTS,
 | 
			
		||||
    PROTOCOL,
 | 
			
		||||
    OVERFLOW_,
 | 
			
		||||
    OTHER,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@@ -49,6 +75,9 @@ enum class PollEvents : u16 {
 | 
			
		||||
    Err = 1 << 3,
 | 
			
		||||
    Hup = 1 << 4,
 | 
			
		||||
    Nval = 1 << 5,
 | 
			
		||||
    RdNorm = 1 << 6,
 | 
			
		||||
    RdBand = 1 << 7,
 | 
			
		||||
    WrBand = 1 << 8,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
DECLARE_ENUM_FLAG_OPERATORS(PollEvents);
 | 
			
		||||
@@ -82,4 +111,11 @@ constexpr IPv4Address TranslateIPv4(in_addr addr) {
 | 
			
		||||
/// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array
 | 
			
		||||
std::optional<IPv4Address> GetHostIPv4Address();
 | 
			
		||||
 | 
			
		||||
std::string IPv4AddressToString(IPv4Address ip_addr);
 | 
			
		||||
u32 IPv4AddressToInteger(IPv4Address ip_addr);
 | 
			
		||||
 | 
			
		||||
// named to avoid name collision with Windows macro
 | 
			
		||||
Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo(
 | 
			
		||||
    const std::string& host, const std::optional<std::string>& service);
 | 
			
		||||
 | 
			
		||||
} // namespace Network
 | 
			
		||||
 
 | 
			
		||||
@@ -98,7 +98,7 @@ Errno ProxySocket::Shutdown(ShutdownHow how) {
 | 
			
		||||
    return Errno::SUCCESS;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) {
 | 
			
		||||
std::pair<s32, Errno> ProxySocket::Recv(int flags, std::span<u8> message) {
 | 
			
		||||
    LOG_WARNING(Network, "(STUBBED) called");
 | 
			
		||||
    ASSERT(flags == 0);
 | 
			
		||||
    ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
 | 
			
		||||
@@ -106,7 +106,7 @@ std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) {
 | 
			
		||||
    return {static_cast<s32>(0), Errno::SUCCESS};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) {
 | 
			
		||||
std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
 | 
			
		||||
    ASSERT(flags == 0);
 | 
			
		||||
    ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
 | 
			
		||||
 | 
			
		||||
@@ -140,8 +140,8 @@ std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message,
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& message,
 | 
			
		||||
                                                 SockAddrIn* addr, std::size_t max_length) {
 | 
			
		||||
std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
 | 
			
		||||
                                                 std::size_t max_length) {
 | 
			
		||||
    ProxyPacket& packet = received_packets.front();
 | 
			
		||||
    if (addr) {
 | 
			
		||||
        addr->family = Domain::INET;
 | 
			
		||||
@@ -153,10 +153,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
 | 
			
		||||
    std::size_t read_bytes;
 | 
			
		||||
    if (packet.data.size() > max_length) {
 | 
			
		||||
        read_bytes = max_length;
 | 
			
		||||
        message.clear();
 | 
			
		||||
        std::copy(packet.data.begin(), packet.data.begin() + read_bytes,
 | 
			
		||||
                  std::back_inserter(message));
 | 
			
		||||
        message.resize(max_length);
 | 
			
		||||
        memcpy(message.data(), packet.data.data(), max_length);
 | 
			
		||||
 | 
			
		||||
        if (protocol == Protocol::UDP) {
 | 
			
		||||
            if (!peek) {
 | 
			
		||||
@@ -171,9 +168,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
 | 
			
		||||
        }
 | 
			
		||||
    } else {
 | 
			
		||||
        read_bytes = packet.data.size();
 | 
			
		||||
        message.clear();
 | 
			
		||||
        std::copy(packet.data.begin(), packet.data.end(), std::back_inserter(message));
 | 
			
		||||
        message.resize(max_length);
 | 
			
		||||
        memcpy(message.data(), packet.data.data(), read_bytes);
 | 
			
		||||
        if (!peek) {
 | 
			
		||||
            received_packets.pop();
 | 
			
		||||
        }
 | 
			
		||||
@@ -293,6 +288,11 @@ Errno ProxySocket::SetNonBlock(bool enable) {
 | 
			
		||||
    return Errno::SUCCESS;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<Errno, Errno> ProxySocket::GetPendingError() {
 | 
			
		||||
    LOG_DEBUG(Network, "(STUBBED) called");
 | 
			
		||||
    return {Errno::SUCCESS, Errno::SUCCESS};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool ProxySocket::IsOpened() const {
 | 
			
		||||
    return fd != INVALID_SOCKET;
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -39,11 +39,11 @@ public:
 | 
			
		||||
 | 
			
		||||
    Errno Shutdown(ShutdownHow how) override;
 | 
			
		||||
 | 
			
		||||
    std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override;
 | 
			
		||||
    std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
 | 
			
		||||
 | 
			
		||||
    std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override;
 | 
			
		||||
    std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
 | 
			
		||||
 | 
			
		||||
    std::pair<s32, Errno> ReceivePacket(int flags, std::vector<u8>& message, SockAddrIn* addr,
 | 
			
		||||
    std::pair<s32, Errno> ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
 | 
			
		||||
                                        std::size_t max_length);
 | 
			
		||||
 | 
			
		||||
    std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
 | 
			
		||||
@@ -74,6 +74,8 @@ public:
 | 
			
		||||
    template <typename T>
 | 
			
		||||
    Errno SetSockOpt(SOCKET fd, int option, T value);
 | 
			
		||||
 | 
			
		||||
    std::pair<Errno, Errno> GetPendingError() override;
 | 
			
		||||
 | 
			
		||||
    bool IsOpened() const override;
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
 
 | 
			
		||||
@@ -59,10 +59,9 @@ public:
 | 
			
		||||
 | 
			
		||||
    virtual Errno Shutdown(ShutdownHow how) = 0;
 | 
			
		||||
 | 
			
		||||
    virtual std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) = 0;
 | 
			
		||||
    virtual std::pair<s32, Errno> Recv(int flags, std::span<u8> message) = 0;
 | 
			
		||||
 | 
			
		||||
    virtual std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message,
 | 
			
		||||
                                           SockAddrIn* addr) = 0;
 | 
			
		||||
    virtual std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) = 0;
 | 
			
		||||
 | 
			
		||||
    virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0;
 | 
			
		||||
 | 
			
		||||
@@ -87,6 +86,8 @@ public:
 | 
			
		||||
 | 
			
		||||
    virtual Errno SetNonBlock(bool enable) = 0;
 | 
			
		||||
 | 
			
		||||
    virtual std::pair<Errno, Errno> GetPendingError() = 0;
 | 
			
		||||
 | 
			
		||||
    virtual bool IsOpened() const = 0;
 | 
			
		||||
 | 
			
		||||
    virtual void HandleProxyPacket(const ProxyPacket& packet) = 0;
 | 
			
		||||
@@ -126,9 +127,9 @@ public:
 | 
			
		||||
 | 
			
		||||
    Errno Shutdown(ShutdownHow how) override;
 | 
			
		||||
 | 
			
		||||
    std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override;
 | 
			
		||||
    std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
 | 
			
		||||
 | 
			
		||||
    std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override;
 | 
			
		||||
    std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
 | 
			
		||||
 | 
			
		||||
    std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
 | 
			
		||||
 | 
			
		||||
@@ -156,6 +157,11 @@ public:
 | 
			
		||||
    template <typename T>
 | 
			
		||||
    Errno SetSockOpt(SOCKET fd, int option, T value);
 | 
			
		||||
 | 
			
		||||
    std::pair<Errno, Errno> GetPendingError() override;
 | 
			
		||||
 | 
			
		||||
    template <typename T>
 | 
			
		||||
    std::pair<T, Errno> GetSockOpt(SOCKET fd, int option);
 | 
			
		||||
 | 
			
		||||
    bool IsOpened() const override;
 | 
			
		||||
 | 
			
		||||
    void HandleProxyPacket(const ProxyPacket& packet) override;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user