Scripting: reimplement protocol over plan UDP using boost::asio
This commit is contained in:
		
							
								
								
									
										18
									
								
								dist/scripting/citra.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										18
									
								
								dist/scripting/citra.py
									
									
									
									
										vendored
									
									
								
							| @@ -1,22 +1,22 @@ | |||||||
| import zmq |  | ||||||
| import struct | import struct | ||||||
| import random | import random | ||||||
| import enum | import enum | ||||||
|  | import socket | ||||||
|  |  | ||||||
| CURRENT_REQUEST_VERSION = 1 | CURRENT_REQUEST_VERSION = 1 | ||||||
| MAX_REQUEST_DATA_SIZE = 32 | MAX_REQUEST_DATA_SIZE = 32 | ||||||
|  | MAX_PACKET_SIZE = 48 | ||||||
|  |  | ||||||
| class RequestType(enum.IntEnum): | class RequestType(enum.IntEnum): | ||||||
|     ReadMemory = 1, |     ReadMemory = 1, | ||||||
|     WriteMemory = 2 |     WriteMemory = 2 | ||||||
|  |  | ||||||
| CITRA_PORT = "45987" | CITRA_PORT = 45987 | ||||||
|  |  | ||||||
| class Citra: | class Citra: | ||||||
|     def __init__(self, address="127.0.0.1", port=CITRA_PORT): |     def __init__(self, address="127.0.0.1", port=CITRA_PORT): | ||||||
|         self.context = zmq.Context() |         self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | ||||||
|         self.socket = self.context.socket(zmq.REQ) |         self.address = address | ||||||
|         self.socket.connect("tcp://" + address + ":" + port) |  | ||||||
|  |  | ||||||
|     def is_connected(self): |     def is_connected(self): | ||||||
|         return self.socket is not None |         return self.socket is not None | ||||||
| @@ -45,9 +45,9 @@ class Citra: | |||||||
|             request_data = struct.pack("II", read_address, temp_read_size) |             request_data = struct.pack("II", read_address, temp_read_size) | ||||||
|             request, request_id = self._generate_header(RequestType.ReadMemory, len(request_data)) |             request, request_id = self._generate_header(RequestType.ReadMemory, len(request_data)) | ||||||
|             request += request_data |             request += request_data | ||||||
|             self.socket.send(request) |             self.socket.sendto(request, (self.address, CITRA_PORT)) | ||||||
|  |  | ||||||
|             raw_reply = self.socket.recv() |             raw_reply = self.socket.recv(MAX_PACKET_SIZE) | ||||||
|             reply_data = self._read_and_validate_header(raw_reply, request_id, RequestType.ReadMemory) |             reply_data = self._read_and_validate_header(raw_reply, request_id, RequestType.ReadMemory) | ||||||
|  |  | ||||||
|             if reply_data: |             if reply_data: | ||||||
| @@ -77,9 +77,9 @@ class Citra: | |||||||
|             request_data += write_contents[:temp_write_size] |             request_data += write_contents[:temp_write_size] | ||||||
|             request, request_id = self._generate_header(RequestType.WriteMemory, len(request_data)) |             request, request_id = self._generate_header(RequestType.WriteMemory, len(request_data)) | ||||||
|             request += request_data |             request += request_data | ||||||
|             self.socket.send(request) |             self.socket.sendto(request, (self.address, CITRA_PORT)) | ||||||
|  |  | ||||||
|             raw_reply = self.socket.recv() |             raw_reply = self.socket.recv(MAX_PACKET_SIZE) | ||||||
|             reply_data = self._read_and_validate_header(raw_reply, request_id, RequestType.WriteMemory) |             reply_data = self._read_and_validate_header(raw_reply, request_id, RequestType.WriteMemory) | ||||||
|  |  | ||||||
|             if None != reply_data: |             if None != reply_data: | ||||||
|   | |||||||
| @@ -439,8 +439,8 @@ if (ENABLE_SCRIPTING) | |||||||
|         rpc/rpc_server.h |         rpc/rpc_server.h | ||||||
|         rpc/server.cpp |         rpc/server.cpp | ||||||
|         rpc/server.h |         rpc/server.h | ||||||
|         rpc/zmq_server.cpp |         rpc/udp_server.cpp | ||||||
|         rpc/zmq_server.h |         rpc/udp_server.h | ||||||
|     ) |     ) | ||||||
| endif() | endif() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -13,6 +13,9 @@ | |||||||
|  |  | ||||||
| namespace RPC { | namespace RPC { | ||||||
|  |  | ||||||
|  | class Packet; | ||||||
|  | struct PacketHeader; | ||||||
|  |  | ||||||
| class RPCServer { | class RPCServer { | ||||||
| public: | public: | ||||||
|     RPCServer(); |     RPCServer(); | ||||||
|   | |||||||
| @@ -1,27 +1,30 @@ | |||||||
| #include <functional> | #include <functional> | ||||||
|  |  | ||||||
| #include "core/core.h" | #include "core/core.h" | ||||||
|  | #include "core/rpc/packet.h" | ||||||
| #include "core/rpc/rpc_server.h" | #include "core/rpc/rpc_server.h" | ||||||
| #include "core/rpc/server.h" | #include "core/rpc/server.h" | ||||||
|  | #include "core/rpc/udp_server.h" | ||||||
|  |  | ||||||
| namespace RPC { | namespace RPC { | ||||||
|  |  | ||||||
| Server::Server(RPCServer& rpc_server) : rpc_server(rpc_server) {} | Server::Server(RPCServer& rpc_server) : rpc_server(rpc_server) {} | ||||||
|  |  | ||||||
|  | Server::~Server() = default; | ||||||
|  |  | ||||||
| void Server::Start() { | void Server::Start() { | ||||||
|     const auto callback = [this](std::unique_ptr<RPC::Packet> new_request) { |     const auto callback = [this](std::unique_ptr<Packet> new_request) { | ||||||
|         NewRequestCallback(std::move(new_request)); |         NewRequestCallback(std::move(new_request)); | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     try { |     try { | ||||||
|         zmq_server = std::make_unique<ZMQServer>(callback); |         udp_server = std::make_unique<UDPServer>(callback); | ||||||
|     } catch (...) { |     } catch (...) { | ||||||
|         LOG_ERROR(RPC_Server, "Error starting ZeroMQ server"); |         LOG_ERROR(RPC_Server, "Error starting UDP server"); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| void Server::Stop() { | void Server::Stop() { | ||||||
|     zmq_server.reset(); |     udp_server.reset(); | ||||||
| } | } | ||||||
|  |  | ||||||
| void Server::NewRequestCallback(std::unique_ptr<RPC::Packet> new_request) { | void Server::NewRequestCallback(std::unique_ptr<RPC::Packet> new_request) { | ||||||
|   | |||||||
| @@ -4,24 +4,25 @@ | |||||||
|  |  | ||||||
| #pragma once | #pragma once | ||||||
|  |  | ||||||
| #include "core/rpc/packet.h" | #include <memory> | ||||||
| #include "core/rpc/zmq_server.h" |  | ||||||
|  |  | ||||||
| namespace RPC { | namespace RPC { | ||||||
|  |  | ||||||
| class RPCServer; | class RPCServer; | ||||||
| class ZMQServer; | class UDPServer; | ||||||
|  | class Packet; | ||||||
|  |  | ||||||
| class Server { | class Server { | ||||||
| public: | public: | ||||||
|     Server(RPCServer& rpc_server); |     Server(RPCServer& rpc_server); | ||||||
|  |     ~Server(); | ||||||
|     void Start(); |     void Start(); | ||||||
|     void Stop(); |     void Stop(); | ||||||
|     void NewRequestCallback(std::unique_ptr<RPC::Packet> new_request); |     void NewRequestCallback(std::unique_ptr<Packet> new_request); | ||||||
|  |  | ||||||
| private: | private: | ||||||
|     RPCServer& rpc_server; |     RPCServer& rpc_server; | ||||||
|     std::unique_ptr<ZMQServer> zmq_server; |     std::unique_ptr<UDPServer> udp_server; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| } // namespace RPC | } // namespace RPC | ||||||
|   | |||||||
							
								
								
									
										100
									
								
								src/core/rpc/udp_server.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								src/core/rpc/udp_server.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | |||||||
|  | // Copyright 2019 Citra Emulator Project | ||||||
|  | // Licensed under GPLv2 or any later version | ||||||
|  | // Refer to the license.txt file included. | ||||||
|  |  | ||||||
|  | #include <thread> | ||||||
|  | #include <boost/asio.hpp> | ||||||
|  | #include "common/common_types.h" | ||||||
|  | #include "common/logging/log.h" | ||||||
|  | #include "core/rpc/packet.h" | ||||||
|  | #include "core/rpc/udp_server.h" | ||||||
|  |  | ||||||
|  | namespace RPC { | ||||||
|  |  | ||||||
|  | class UDPServer::Impl { | ||||||
|  | public: | ||||||
|  |     explicit Impl(std::function<void(std::unique_ptr<Packet>)> new_request_callback) | ||||||
|  |         // Use a random high port | ||||||
|  |         // TODO: Make configurable or increment port number on failure | ||||||
|  |         : socket(io_context, boost::asio::ip::udp::endpoint(boost::asio::ip::udp::v4(), 45987)), | ||||||
|  |           new_request_callback(std::move(new_request_callback)) { | ||||||
|  |  | ||||||
|  |         StartReceive(); | ||||||
|  |         worker_thread = std::thread([this] { | ||||||
|  |             io_context.run(); | ||||||
|  |             this->new_request_callback(nullptr); | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     ~Impl() { | ||||||
|  |         io_context.stop(); | ||||||
|  |         worker_thread.join(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  | private: | ||||||
|  |     void StartReceive() { | ||||||
|  |         socket.async_receive_from(boost::asio::buffer(request_buffer), remote_endpoint, | ||||||
|  |                                   [this](const boost::system::error_code& error, std::size_t size) { | ||||||
|  |                                       HandleReceive(error, size); | ||||||
|  |                                   }); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void HandleReceive(const boost::system::error_code& error, std::size_t size) { | ||||||
|  |         if (error) { | ||||||
|  |             LOG_WARNING(RPC_Server, "Failed to receive data on UDP socket: {}", error.message()); | ||||||
|  |         } else if (size >= MIN_PACKET_SIZE && size <= MAX_PACKET_SIZE) { | ||||||
|  |             PacketHeader header; | ||||||
|  |             std::memcpy(&header, request_buffer.data(), sizeof(header)); | ||||||
|  |             if ((size - MIN_PACKET_SIZE) == header.packet_size) { | ||||||
|  |                 u8* data = request_buffer.data() + MIN_PACKET_SIZE; | ||||||
|  |                 std::function<void(Packet&)> send_reply_callback = | ||||||
|  |                     std::bind(&Impl::SendReply, this, remote_endpoint, std::placeholders::_1); | ||||||
|  |                 std::unique_ptr<Packet> new_packet = | ||||||
|  |                     std::make_unique<Packet>(header, data, send_reply_callback); | ||||||
|  |  | ||||||
|  |                 // Send the request to the upper layer for handling | ||||||
|  |                 new_request_callback(std::move(new_packet)); | ||||||
|  |             } | ||||||
|  |         } else { | ||||||
|  |             LOG_WARNING(RPC_Server, "Received message with wrong size: {}", size); | ||||||
|  |         } | ||||||
|  |         StartReceive(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void SendReply(boost::asio::ip::udp::endpoint endpoint, Packet& reply_packet) { | ||||||
|  |         std::vector<u8> reply_buffer(MIN_PACKET_SIZE + reply_packet.GetPacketDataSize()); | ||||||
|  |         auto reply_header = reply_packet.GetHeader(); | ||||||
|  |  | ||||||
|  |         std::memcpy(reply_buffer.data(), &reply_header, sizeof(reply_header)); | ||||||
|  |         std::memcpy(reply_buffer.data() + (4 * sizeof(u32)), reply_packet.GetPacketData().data(), | ||||||
|  |                     reply_packet.GetPacketDataSize()); | ||||||
|  |  | ||||||
|  |         boost::system::error_code error; | ||||||
|  |         socket.send_to(boost::asio::buffer(reply_buffer), endpoint, 0, error); | ||||||
|  |  | ||||||
|  |         if (error) { | ||||||
|  |             LOG_WARNING(RPC_Server, "Failed to send reply: {}", error.message()); | ||||||
|  |         } else { | ||||||
|  |             LOG_INFO(RPC_Server, "Sent reply version({}) id=({}) type=({}) size=({})", | ||||||
|  |                      reply_packet.GetVersion(), reply_packet.GetId(), | ||||||
|  |                      static_cast<u32>(reply_packet.GetPacketType()), | ||||||
|  |                      reply_packet.GetPacketDataSize()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     std::thread worker_thread; | ||||||
|  |  | ||||||
|  |     boost::asio::io_context io_context; | ||||||
|  |     boost::asio::ip::udp::socket socket; | ||||||
|  |     std::array<u8, MAX_PACKET_SIZE> request_buffer; | ||||||
|  |     boost::asio::ip::udp::endpoint remote_endpoint; | ||||||
|  |  | ||||||
|  |     std::function<void(std::unique_ptr<Packet>)> new_request_callback; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | UDPServer::UDPServer(std::function<void(std::unique_ptr<Packet>)> new_request_callback) | ||||||
|  |     : impl(std::make_unique<Impl>(new_request_callback)) {} | ||||||
|  |  | ||||||
|  | UDPServer::~UDPServer() = default; | ||||||
|  |  | ||||||
|  | } // namespace RPC | ||||||
							
								
								
									
										24
									
								
								src/core/rpc/udp_server.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								src/core/rpc/udp_server.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | |||||||
|  | // Copyright 2019 Citra Emulator Project | ||||||
|  | // Licensed under GPLv2 or any later version | ||||||
|  | // Refer to the license.txt file included. | ||||||
|  |  | ||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <functional> | ||||||
|  | #include <memory> | ||||||
|  |  | ||||||
|  | namespace RPC { | ||||||
|  |  | ||||||
|  | class Packet; | ||||||
|  |  | ||||||
|  | class UDPServer { | ||||||
|  | public: | ||||||
|  |     explicit UDPServer(std::function<void(std::unique_ptr<Packet>)> new_request_callback); | ||||||
|  |     ~UDPServer(); | ||||||
|  |  | ||||||
|  | private: | ||||||
|  |     class Impl; | ||||||
|  |     std::unique_ptr<Impl> impl; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | } // namespace RPC | ||||||
| @@ -1,79 +0,0 @@ | |||||||
| #include "common/common_types.h" |  | ||||||
| #include "core/core.h" |  | ||||||
| #include "core/rpc/packet.h" |  | ||||||
| #include "core/rpc/zmq_server.h" |  | ||||||
|  |  | ||||||
| namespace RPC { |  | ||||||
|  |  | ||||||
| ZMQServer::ZMQServer(std::function<void(std::unique_ptr<Packet>)> new_request_callback) |  | ||||||
|     : zmq_context(std::move(std::make_unique<zmq::context_t>(1))), |  | ||||||
|       zmq_socket(std::move(std::make_unique<zmq::socket_t>(*zmq_context, ZMQ_REP))), |  | ||||||
|       new_request_callback(std::move(new_request_callback)) { |  | ||||||
|     // Use a random high port |  | ||||||
|     // TODO: Make configurable or increment port number on failure |  | ||||||
|     zmq_socket->bind("tcp://127.0.0.1:45987"); |  | ||||||
|     LOG_INFO(RPC_Server, "ZeroMQ listening on port 45987"); |  | ||||||
|  |  | ||||||
|     worker_thread = std::thread(&ZMQServer::WorkerLoop, this); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| ZMQServer::~ZMQServer() { |  | ||||||
|     // Triggering the zmq_context destructor will cancel |  | ||||||
|     // any blocking calls to zmq_socket->recv() |  | ||||||
|     running = false; |  | ||||||
|     zmq_context.reset(); |  | ||||||
|     worker_thread.join(); |  | ||||||
|  |  | ||||||
|     LOG_INFO(RPC_Server, "ZeroMQ stopped"); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void ZMQServer::WorkerLoop() { |  | ||||||
|     zmq::message_t request; |  | ||||||
|     while (running) { |  | ||||||
|         try { |  | ||||||
|             if (zmq_socket->recv(&request, 0)) { |  | ||||||
|                 if (request.size() >= MIN_PACKET_SIZE && request.size() <= MAX_PACKET_SIZE) { |  | ||||||
|                     u8* request_buffer = static_cast<u8*>(request.data()); |  | ||||||
|                     PacketHeader header; |  | ||||||
|                     std::memcpy(&header, request_buffer, sizeof(header)); |  | ||||||
|                     if ((request.size() - MIN_PACKET_SIZE) == header.packet_size) { |  | ||||||
|                         u8* data = request_buffer + MIN_PACKET_SIZE; |  | ||||||
|                         std::function<void(Packet&)> send_reply_callback = |  | ||||||
|                             std::bind(&ZMQServer::SendReply, this, std::placeholders::_1); |  | ||||||
|                         std::unique_ptr<Packet> new_packet = |  | ||||||
|                             std::make_unique<Packet>(header, data, send_reply_callback); |  | ||||||
|  |  | ||||||
|                         // Send the request to the upper layer for handling |  | ||||||
|                         new_request_callback(std::move(new_packet)); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } catch (...) { |  | ||||||
|             LOG_WARNING(RPC_Server, "Failed to receive data on ZeroMQ socket"); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     std::unique_ptr<Packet> end_packet = nullptr; |  | ||||||
|     new_request_callback(std::move(end_packet)); |  | ||||||
|     // Destroying the socket must be done by this thread. |  | ||||||
|     zmq_socket.reset(); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void ZMQServer::SendReply(Packet& reply_packet) { |  | ||||||
|     if (running) { |  | ||||||
|         auto reply_buffer = |  | ||||||
|             std::make_unique<u8[]>(MIN_PACKET_SIZE + reply_packet.GetPacketDataSize()); |  | ||||||
|         auto reply_header = reply_packet.GetHeader(); |  | ||||||
|  |  | ||||||
|         std::memcpy(reply_buffer.get(), &reply_header, sizeof(reply_header)); |  | ||||||
|         std::memcpy(reply_buffer.get() + (4 * sizeof(u32)), reply_packet.GetPacketData().data(), |  | ||||||
|                     reply_packet.GetPacketDataSize()); |  | ||||||
|  |  | ||||||
|         zmq_socket->send(reply_buffer.get(), MIN_PACKET_SIZE + reply_packet.GetPacketDataSize()); |  | ||||||
|  |  | ||||||
|         LOG_INFO(RPC_Server, "Sent reply version({}) id=({}) type=({}) size=({})", |  | ||||||
|                  reply_packet.GetVersion(), reply_packet.GetId(), |  | ||||||
|                  static_cast<u32>(reply_packet.GetPacketType()), reply_packet.GetPacketDataSize()); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| }; // namespace RPC |  | ||||||
| @@ -1,34 +0,0 @@ | |||||||
| // Copyright 2018 Citra Emulator Project |  | ||||||
| // Licensed under GPLv2 or any later version |  | ||||||
| // Refer to the license.txt file included. |  | ||||||
|  |  | ||||||
| #pragma once |  | ||||||
|  |  | ||||||
| #include <functional> |  | ||||||
| #include <thread> |  | ||||||
| #define ZMQ_STATIC |  | ||||||
| #include <zmq.hpp> |  | ||||||
|  |  | ||||||
| namespace RPC { |  | ||||||
|  |  | ||||||
| class Packet; |  | ||||||
|  |  | ||||||
| class ZMQServer { |  | ||||||
| public: |  | ||||||
|     explicit ZMQServer(std::function<void(std::unique_ptr<Packet>)> new_request_callback); |  | ||||||
|     ~ZMQServer(); |  | ||||||
|  |  | ||||||
| private: |  | ||||||
|     void WorkerLoop(); |  | ||||||
|     void SendReply(Packet& request); |  | ||||||
|  |  | ||||||
|     std::thread worker_thread; |  | ||||||
|     std::atomic_bool running = true; |  | ||||||
|  |  | ||||||
|     std::unique_ptr<zmq::context_t> zmq_context; |  | ||||||
|     std::unique_ptr<zmq::socket_t> zmq_socket; |  | ||||||
|  |  | ||||||
|     std::function<void(std::unique_ptr<Packet>)> new_request_callback; |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| } // namespace RPC |  | ||||||
		Reference in New Issue
	
	Block a user