From d765a73a53105f09fda47f58b8a5cbb902e398b4 Mon Sep 17 00:00:00 2001 From: Weiyi Wang Date: Tue, 29 Jan 2019 17:28:50 -0500 Subject: [PATCH] Scripting: reimplement protocol over plan UDP using boost::asio --- dist/scripting/citra.py | 18 +++---- src/core/CMakeLists.txt | 4 +- src/core/rpc/rpc_server.h | 3 ++ src/core/rpc/server.cpp | 13 +++-- src/core/rpc/server.h | 11 ++-- src/core/rpc/udp_server.cpp | 100 ++++++++++++++++++++++++++++++++++++ src/core/rpc/udp_server.h | 24 +++++++++ src/core/rpc/zmq_server.cpp | 79 ---------------------------- src/core/rpc/zmq_server.h | 34 ------------ 9 files changed, 152 insertions(+), 134 deletions(-) create mode 100644 src/core/rpc/udp_server.cpp create mode 100644 src/core/rpc/udp_server.h delete mode 100644 src/core/rpc/zmq_server.cpp delete mode 100644 src/core/rpc/zmq_server.h diff --git a/dist/scripting/citra.py b/dist/scripting/citra.py index be6068685..507662033 100644 --- a/dist/scripting/citra.py +++ b/dist/scripting/citra.py @@ -1,22 +1,22 @@ -import zmq import struct import random import enum +import socket CURRENT_REQUEST_VERSION = 1 MAX_REQUEST_DATA_SIZE = 32 +MAX_PACKET_SIZE = 48 class RequestType(enum.IntEnum): ReadMemory = 1, WriteMemory = 2 -CITRA_PORT = "45987" +CITRA_PORT = 45987 class Citra: def __init__(self, address="127.0.0.1", port=CITRA_PORT): - self.context = zmq.Context() - self.socket = self.context.socket(zmq.REQ) - self.socket.connect("tcp://" + address + ":" + port) + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.address = address def is_connected(self): return self.socket is not None @@ -45,9 +45,9 @@ class Citra: request_data = struct.pack("II", read_address, temp_read_size) request, request_id = self._generate_header(RequestType.ReadMemory, len(request_data)) request += request_data - self.socket.send(request) + self.socket.sendto(request, (self.address, CITRA_PORT)) - raw_reply = self.socket.recv() + raw_reply = self.socket.recv(MAX_PACKET_SIZE) reply_data = self._read_and_validate_header(raw_reply, request_id, RequestType.ReadMemory) if reply_data: @@ -77,9 +77,9 @@ class Citra: request_data += write_contents[:temp_write_size] request, request_id = self._generate_header(RequestType.WriteMemory, len(request_data)) request += request_data - self.socket.send(request) + self.socket.sendto(request, (self.address, CITRA_PORT)) - raw_reply = self.socket.recv() + raw_reply = self.socket.recv(MAX_PACKET_SIZE) reply_data = self._read_and_validate_header(raw_reply, request_id, RequestType.WriteMemory) if None != reply_data: diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 1e9f0248b..2098a296b 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -439,8 +439,8 @@ if (ENABLE_SCRIPTING) rpc/rpc_server.h rpc/server.cpp rpc/server.h - rpc/zmq_server.cpp - rpc/zmq_server.h + rpc/udp_server.cpp + rpc/udp_server.h ) endif() diff --git a/src/core/rpc/rpc_server.h b/src/core/rpc/rpc_server.h index bb57bcdae..83fe4f60c 100644 --- a/src/core/rpc/rpc_server.h +++ b/src/core/rpc/rpc_server.h @@ -13,6 +13,9 @@ namespace RPC { +class Packet; +struct PacketHeader; + class RPCServer { public: RPCServer(); diff --git a/src/core/rpc/server.cpp b/src/core/rpc/server.cpp index b3e66d6ff..3474fa7e1 100644 --- a/src/core/rpc/server.cpp +++ b/src/core/rpc/server.cpp @@ -1,27 +1,30 @@ #include - #include "core/core.h" +#include "core/rpc/packet.h" #include "core/rpc/rpc_server.h" #include "core/rpc/server.h" +#include "core/rpc/udp_server.h" namespace RPC { Server::Server(RPCServer& rpc_server) : rpc_server(rpc_server) {} +Server::~Server() = default; + void Server::Start() { - const auto callback = [this](std::unique_ptr new_request) { + const auto callback = [this](std::unique_ptr new_request) { NewRequestCallback(std::move(new_request)); }; try { - zmq_server = std::make_unique(callback); + udp_server = std::make_unique(callback); } catch (...) { - LOG_ERROR(RPC_Server, "Error starting ZeroMQ server"); + LOG_ERROR(RPC_Server, "Error starting UDP server"); } } void Server::Stop() { - zmq_server.reset(); + udp_server.reset(); } void Server::NewRequestCallback(std::unique_ptr new_request) { diff --git a/src/core/rpc/server.h b/src/core/rpc/server.h index 2dfad2ef7..c9f27cd5c 100644 --- a/src/core/rpc/server.h +++ b/src/core/rpc/server.h @@ -4,24 +4,25 @@ #pragma once -#include "core/rpc/packet.h" -#include "core/rpc/zmq_server.h" +#include namespace RPC { class RPCServer; -class ZMQServer; +class UDPServer; +class Packet; class Server { public: Server(RPCServer& rpc_server); + ~Server(); void Start(); void Stop(); - void NewRequestCallback(std::unique_ptr new_request); + void NewRequestCallback(std::unique_ptr new_request); private: RPCServer& rpc_server; - std::unique_ptr zmq_server; + std::unique_ptr udp_server; }; } // namespace RPC diff --git a/src/core/rpc/udp_server.cpp b/src/core/rpc/udp_server.cpp new file mode 100644 index 000000000..185450f5e --- /dev/null +++ b/src/core/rpc/udp_server.cpp @@ -0,0 +1,100 @@ +// Copyright 2019 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include +#include +#include "common/common_types.h" +#include "common/logging/log.h" +#include "core/rpc/packet.h" +#include "core/rpc/udp_server.h" + +namespace RPC { + +class UDPServer::Impl { +public: + explicit Impl(std::function)> new_request_callback) + // Use a random high port + // TODO: Make configurable or increment port number on failure + : socket(io_context, boost::asio::ip::udp::endpoint(boost::asio::ip::udp::v4(), 45987)), + new_request_callback(std::move(new_request_callback)) { + + StartReceive(); + worker_thread = std::thread([this] { + io_context.run(); + this->new_request_callback(nullptr); + }); + } + + ~Impl() { + io_context.stop(); + worker_thread.join(); + } + +private: + void StartReceive() { + socket.async_receive_from(boost::asio::buffer(request_buffer), remote_endpoint, + [this](const boost::system::error_code& error, std::size_t size) { + HandleReceive(error, size); + }); + } + + void HandleReceive(const boost::system::error_code& error, std::size_t size) { + if (error) { + LOG_WARNING(RPC_Server, "Failed to receive data on UDP socket: {}", error.message()); + } else if (size >= MIN_PACKET_SIZE && size <= MAX_PACKET_SIZE) { + PacketHeader header; + std::memcpy(&header, request_buffer.data(), sizeof(header)); + if ((size - MIN_PACKET_SIZE) == header.packet_size) { + u8* data = request_buffer.data() + MIN_PACKET_SIZE; + std::function send_reply_callback = + std::bind(&Impl::SendReply, this, remote_endpoint, std::placeholders::_1); + std::unique_ptr new_packet = + std::make_unique(header, data, send_reply_callback); + + // Send the request to the upper layer for handling + new_request_callback(std::move(new_packet)); + } + } else { + LOG_WARNING(RPC_Server, "Received message with wrong size: {}", size); + } + StartReceive(); + } + + void SendReply(boost::asio::ip::udp::endpoint endpoint, Packet& reply_packet) { + std::vector reply_buffer(MIN_PACKET_SIZE + reply_packet.GetPacketDataSize()); + auto reply_header = reply_packet.GetHeader(); + + std::memcpy(reply_buffer.data(), &reply_header, sizeof(reply_header)); + std::memcpy(reply_buffer.data() + (4 * sizeof(u32)), reply_packet.GetPacketData().data(), + reply_packet.GetPacketDataSize()); + + boost::system::error_code error; + socket.send_to(boost::asio::buffer(reply_buffer), endpoint, 0, error); + + if (error) { + LOG_WARNING(RPC_Server, "Failed to send reply: {}", error.message()); + } else { + LOG_INFO(RPC_Server, "Sent reply version({}) id=({}) type=({}) size=({})", + reply_packet.GetVersion(), reply_packet.GetId(), + static_cast(reply_packet.GetPacketType()), + reply_packet.GetPacketDataSize()); + } + } + + std::thread worker_thread; + + boost::asio::io_context io_context; + boost::asio::ip::udp::socket socket; + std::array request_buffer; + boost::asio::ip::udp::endpoint remote_endpoint; + + std::function)> new_request_callback; +}; + +UDPServer::UDPServer(std::function)> new_request_callback) + : impl(std::make_unique(new_request_callback)) {} + +UDPServer::~UDPServer() = default; + +} // namespace RPC diff --git a/src/core/rpc/udp_server.h b/src/core/rpc/udp_server.h new file mode 100644 index 000000000..f4ff2ad62 --- /dev/null +++ b/src/core/rpc/udp_server.h @@ -0,0 +1,24 @@ +// Copyright 2019 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include +#include + +namespace RPC { + +class Packet; + +class UDPServer { +public: + explicit UDPServer(std::function)> new_request_callback); + ~UDPServer(); + +private: + class Impl; + std::unique_ptr impl; +}; + +} // namespace RPC diff --git a/src/core/rpc/zmq_server.cpp b/src/core/rpc/zmq_server.cpp deleted file mode 100644 index 47885973c..000000000 --- a/src/core/rpc/zmq_server.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include "common/common_types.h" -#include "core/core.h" -#include "core/rpc/packet.h" -#include "core/rpc/zmq_server.h" - -namespace RPC { - -ZMQServer::ZMQServer(std::function)> new_request_callback) - : zmq_context(std::move(std::make_unique(1))), - zmq_socket(std::move(std::make_unique(*zmq_context, ZMQ_REP))), - new_request_callback(std::move(new_request_callback)) { - // Use a random high port - // TODO: Make configurable or increment port number on failure - zmq_socket->bind("tcp://127.0.0.1:45987"); - LOG_INFO(RPC_Server, "ZeroMQ listening on port 45987"); - - worker_thread = std::thread(&ZMQServer::WorkerLoop, this); -} - -ZMQServer::~ZMQServer() { - // Triggering the zmq_context destructor will cancel - // any blocking calls to zmq_socket->recv() - running = false; - zmq_context.reset(); - worker_thread.join(); - - LOG_INFO(RPC_Server, "ZeroMQ stopped"); -} - -void ZMQServer::WorkerLoop() { - zmq::message_t request; - while (running) { - try { - if (zmq_socket->recv(&request, 0)) { - if (request.size() >= MIN_PACKET_SIZE && request.size() <= MAX_PACKET_SIZE) { - u8* request_buffer = static_cast(request.data()); - PacketHeader header; - std::memcpy(&header, request_buffer, sizeof(header)); - if ((request.size() - MIN_PACKET_SIZE) == header.packet_size) { - u8* data = request_buffer + MIN_PACKET_SIZE; - std::function send_reply_callback = - std::bind(&ZMQServer::SendReply, this, std::placeholders::_1); - std::unique_ptr new_packet = - std::make_unique(header, data, send_reply_callback); - - // Send the request to the upper layer for handling - new_request_callback(std::move(new_packet)); - } - } - } - } catch (...) { - LOG_WARNING(RPC_Server, "Failed to receive data on ZeroMQ socket"); - } - } - std::unique_ptr end_packet = nullptr; - new_request_callback(std::move(end_packet)); - // Destroying the socket must be done by this thread. - zmq_socket.reset(); -} - -void ZMQServer::SendReply(Packet& reply_packet) { - if (running) { - auto reply_buffer = - std::make_unique(MIN_PACKET_SIZE + reply_packet.GetPacketDataSize()); - auto reply_header = reply_packet.GetHeader(); - - std::memcpy(reply_buffer.get(), &reply_header, sizeof(reply_header)); - std::memcpy(reply_buffer.get() + (4 * sizeof(u32)), reply_packet.GetPacketData().data(), - reply_packet.GetPacketDataSize()); - - zmq_socket->send(reply_buffer.get(), MIN_PACKET_SIZE + reply_packet.GetPacketDataSize()); - - LOG_INFO(RPC_Server, "Sent reply version({}) id=({}) type=({}) size=({})", - reply_packet.GetVersion(), reply_packet.GetId(), - static_cast(reply_packet.GetPacketType()), reply_packet.GetPacketDataSize()); - } -} - -}; // namespace RPC diff --git a/src/core/rpc/zmq_server.h b/src/core/rpc/zmq_server.h deleted file mode 100644 index 784fccf5a..000000000 --- a/src/core/rpc/zmq_server.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2018 Citra Emulator Project -// Licensed under GPLv2 or any later version -// Refer to the license.txt file included. - -#pragma once - -#include -#include -#define ZMQ_STATIC -#include - -namespace RPC { - -class Packet; - -class ZMQServer { -public: - explicit ZMQServer(std::function)> new_request_callback); - ~ZMQServer(); - -private: - void WorkerLoop(); - void SendReply(Packet& request); - - std::thread worker_thread; - std::atomic_bool running = true; - - std::unique_ptr zmq_context; - std::unique_ptr zmq_socket; - - std::function)> new_request_callback; -}; - -} // namespace RPC