From bdad00c73f46106ba78995bdde1b50349e940b09 Mon Sep 17 00:00:00 2001
From: Subv <subv2112@gmail.com>
Date: Sun, 4 Dec 2016 09:58:36 -0500
Subject: [PATCH] Threading: Added some utility functions and const
 correctness.

---
 src/citra_qt/debugger/wait_tree.cpp |  2 +-
 src/core/hle/kernel/kernel.cpp      | 13 ++++++-------
 src/core/hle/kernel/thread.h        | 19 ++++++++++++++++---
 src/core/hle/svc.cpp                | 18 +++++++++++++-----
 4 files changed, 36 insertions(+), 16 deletions(-)

diff --git a/src/citra_qt/debugger/wait_tree.cpp b/src/citra_qt/debugger/wait_tree.cpp
index 8fc3e37e0..829ac7dd6 100644
--- a/src/citra_qt/debugger/wait_tree.cpp
+++ b/src/citra_qt/debugger/wait_tree.cpp
@@ -230,7 +230,7 @@ std::vector<std::unique_ptr<WaitTreeItem>> WaitTreeThread::GetChildren() const {
         list.push_back(std::make_unique<WaitTreeMutexList>(thread.held_mutexes));
     }
     if (thread.status == THREADSTATUS_WAIT_SYNCH) {
-        list.push_back(std::make_unique<WaitTreeObjectList>(thread.wait_objects, !thread.wait_objects.empty()));
+        list.push_back(std::make_unique<WaitTreeObjectList>(thread.wait_objects, thread.IsWaitingAll()));
     }
 
     return list;
diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp
index be7a5a6d8..6d358def7 100644
--- a/src/core/hle/kernel/kernel.cpp
+++ b/src/core/hle/kernel/kernel.cpp
@@ -33,7 +33,7 @@ void WaitObject::RemoveWaitingThread(Thread* thread) {
 
 SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() {
     // Remove the threads that are ready or already running from our waitlist
-    waiting_threads.erase(std::remove_if(waiting_threads.begin(), waiting_threads.end(), [](SharedPtr<Thread> thread) -> bool {
+    waiting_threads.erase(std::remove_if(waiting_threads.begin(), waiting_threads.end(), [](const SharedPtr<Thread>& thread) -> bool {
         return thread->status == THREADSTATUS_RUNNING || thread->status == THREADSTATUS_READY;
     }), waiting_threads.end());
 
@@ -42,12 +42,11 @@ SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() {
 
     auto candidate_threads = waiting_threads;
 
-    // Eliminate all threads that are waiting on more than one object, and not all of them are ready
-    candidate_threads.erase(std::remove_if(candidate_threads.begin(), candidate_threads.end(), [](SharedPtr<Thread> thread) -> bool {
-        for (auto object : thread->wait_objects)
-            if (object->ShouldWait())
-                return true;
-        return false;
+    // Eliminate all threads that are waiting on more than one object, and not all of said objects are ready
+    candidate_threads.erase(std::remove_if(candidate_threads.begin(), candidate_threads.end(), [](const SharedPtr<Thread>& thread) -> bool {
+        return std::any_of(thread->wait_objects.begin(), thread->wait_objects.end(), [](const SharedPtr<WaitObject>& object) -> bool {
+            return object->ShouldWait();
+        });
     }), candidate_threads.end());
 
     // Return the thread with the lowest priority value (The one with the highest priority)
diff --git a/src/core/hle/kernel/thread.h b/src/core/hle/kernel/thread.h
index 63b97b74f..1b29fb3a3 100644
--- a/src/core/hle/kernel/thread.h
+++ b/src/core/hle/kernel/thread.h
@@ -131,8 +131,8 @@ public:
      * It is used to set the output value of WaitSynchronizationN when the thread is awakened.
      * @param object Object to query the index of.
      */
-    s32 GetWaitObjectIndex(WaitObject* object) {
-        return wait_objects_index[object->GetObjectId()];
+    s32 GetWaitObjectIndex(const WaitObject* object) const {
+        return wait_objects_index.at(object->GetObjectId());
     }
 
     /**
@@ -148,6 +148,15 @@ public:
         return tls_address;
     }
 
+    /**
+     * Returns whether this thread is waiting for all the objects in
+     * its wait list to become ready, as a result of a WaitSynchronizationN call
+     * with wait_all = true, or a ReplyAndReceive call.
+     */
+    bool IsWaitingAll() const {
+        return !wait_objects.empty();
+    }
+
     Core::ThreadContext context;
 
     u32 thread_id;
@@ -169,7 +178,11 @@ public:
     boost::container::flat_set<SharedPtr<Mutex>> held_mutexes;
 
     SharedPtr<Process> owner_process;                ///< Process that owns this thread
-    std::vector<SharedPtr<WaitObject>> wait_objects; ///< Objects that the thread is waiting on
+
+    /// Objects that the thread is waiting on.
+    /// This is only populated when the thread should wait for all the objects to become ready.
+    std::vector<SharedPtr<WaitObject>> wait_objects;
+
     std::unordered_map<int, s32> wait_objects_index; ///< Mapping of Object ids to their position in the last waitlist that this object waited on.
 
     VAddr wait_address;   ///< If waiting on an AddressArbiter, this is the arbitration address
diff --git a/src/core/hle/svc.cpp b/src/core/hle/svc.cpp
index 061692af8..c06df84b3 100644
--- a/src/core/hle/svc.cpp
+++ b/src/core/hle/svc.cpp
@@ -257,18 +257,21 @@ static ResultCode WaitSynchronization1(Handle handle, s64 nano_seconds) {
 
     if (object->ShouldWait()) {
 
-        if (nano_seconds == 0)
+        if (nano_seconds == 0) {
             return ResultCode(ErrorDescription::Timeout, ErrorModule::OS,
                               ErrorSummary::StatusChanged,
                               ErrorLevel::Info);
+        }
 
         object->AddWaitingThread(thread);
+        // TODO(Subv): Perform things like update the mutex lock owner's priority to prevent priority inversion.
+        // Currently this is done in Mutex::ShouldWait, but it should be moved to a function that is called from here.
         thread->status = THREADSTATUS_WAIT_SYNCH;
 
         // Create an event to wake the thread up after the specified nanosecond delay has passed
         thread->WakeAfterDelay(nano_seconds);
 
-        // Note: The output of this SVC will be set to RESULT_SUCCESS if the thread resumes due to a signal in one of its wait objects.
+        // Note: The output of this SVC will be set to RESULT_SUCCESS if the thread resumes due to a signal in its wait objects.
         // Otherwise we retain the default value of timeout.
         return ResultCode(ErrorDescription::Timeout, ErrorModule::OS,
                                ErrorSummary::StatusChanged,
@@ -312,7 +315,9 @@ static ResultCode WaitSynchronizationN(s32* out, Handle* handles, s32 handle_cou
         objects[i] = object;
     }
 
-    // Clear the mapping of wait object indices
+    // Clear the mapping of wait object indices.
+    // We don't want any lingering state in this map.
+    // It will be repopulated later in the wait_all = false case.
     thread->wait_objects_index.clear();
 
     if (!wait_all) {
@@ -345,12 +350,13 @@ static ResultCode WaitSynchronizationN(s32* out, Handle* handles, s32 handle_cou
         thread->wait_objects.clear();
 
         // Add the thread to each of the objects' waiting threads.
-        for (int i = 0; i < objects.size(); ++i) {
+        for (size_t i = 0; i < objects.size(); ++i) {
             ObjectPtr object = objects[i];
             // Set the index of this object in the mapping of Objects -> index for this thread.
-            thread->wait_objects_index[object->GetObjectId()] = i;
+            thread->wait_objects_index[object->GetObjectId()] = static_cast<int>(i);
             object->AddWaitingThread(thread);
             // TODO(Subv): Perform things like update the mutex lock owner's priority to prevent priority inversion.
+            // Currently this is done in Mutex::ShouldWait, but it should be moved to a function that is called from here.
         }
 
         // Note: If no handles and no timeout were given, then the thread will deadlock, this is consistent with hardware behavior.
@@ -396,6 +402,7 @@ static ResultCode WaitSynchronizationN(s32* out, Handle* handles, s32 handle_cou
         for (auto object : objects) {
             object->AddWaitingThread(thread);
             // TODO(Subv): Perform things like update the mutex lock owner's priority to prevent priority inversion.
+            // Currently this is done in Mutex::ShouldWait, but it should be moved to a function that is called from here.
         }
 
         // Create an event to wake the thread up after the specified nanosecond delay has passed
@@ -1172,6 +1179,7 @@ void CallSVC(u32 immediate) {
     if (info) {
         if (info->func) {
             info->func();
+            //  TODO(Subv): Not all service functions should cause a reschedule in all cases.
             HLE::Reschedule(__func__);
         } else {
             LOG_ERROR(Kernel_SVC, "unimplemented SVC function %s(..)", info->name);