diff --git a/src/core/hle/kernel/address_arbiter.cpp b/src/core/hle/kernel/address_arbiter.cpp index 8475b698c..ebabde921 100644 --- a/src/core/hle/kernel/address_arbiter.cpp +++ b/src/core/hle/kernel/address_arbiter.cpp @@ -7,11 +7,15 @@ #include "common/assert.h" #include "common/common_types.h" +#include "core/arm/exclusive_monitor.h" #include "core/core.h" #include "core/hle/kernel/address_arbiter.h" #include "core/hle/kernel/errors.h" +#include "core/hle/kernel/handle_table.h" +#include "core/hle/kernel/kernel.h" #include "core/hle/kernel/scheduler.h" #include "core/hle/kernel/thread.h" +#include "core/hle/kernel/time_manager.h" #include "core/hle/result.h" #include "core/memory.h" @@ -20,6 +24,7 @@ namespace Kernel { // Wake up num_to_wake (or all) threads in a vector. void AddressArbiter::WakeThreads(const std::vector>& waiting_threads, s32 num_to_wake) { + auto& time_manager = system.Kernel().TimeManager(); // Only process up to 'target' threads, unless 'target' is <= 0, in which case process // them all. std::size_t last = waiting_threads.size(); @@ -29,12 +34,20 @@ void AddressArbiter::WakeThreads(const std::vector>& wai // Signal the waiting threads. for (std::size_t i = 0; i < last; i++) { + if (waiting_threads[i]->GetStatus() != ThreadStatus::WaitArb) { + last++; + last = std::min(waiting_threads.size(), last); + continue; + } + + time_manager.CancelTimeEvent(waiting_threads[i].get()); + ASSERT(waiting_threads[i]->GetStatus() == ThreadStatus::WaitArb); - waiting_threads[i]->SetWaitSynchronizationResult(RESULT_SUCCESS); + waiting_threads[i]->SetSynchronizationResults(nullptr, RESULT_SUCCESS); RemoveThread(waiting_threads[i]); + waiting_threads[i]->WaitForArbitration(false); waiting_threads[i]->SetArbiterWaitAddress(0); waiting_threads[i]->ResumeFromWait(); - system.PrepareReschedule(waiting_threads[i]->GetProcessorID()); } } @@ -56,6 +69,7 @@ ResultCode AddressArbiter::SignalToAddress(VAddr address, SignalType type, s32 v } ResultCode AddressArbiter::SignalToAddressOnly(VAddr address, s32 num_to_wake) { + SchedulerLock lock(system.Kernel()); const std::vector> waiting_threads = GetThreadsWaitingOnAddress(address); WakeThreads(waiting_threads, num_to_wake); @@ -64,6 +78,7 @@ ResultCode AddressArbiter::SignalToAddressOnly(VAddr address, s32 num_to_wake) { ResultCode AddressArbiter::IncrementAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake) { + SchedulerLock lock(system.Kernel()); auto& memory = system.Memory(); // Ensure that we can write to the address. @@ -71,16 +86,25 @@ ResultCode AddressArbiter::IncrementAndSignalToAddressIfEqual(VAddr address, s32 return ERR_INVALID_ADDRESS_STATE; } - if (static_cast(memory.Read32(address)) != value) { - return ERR_INVALID_STATE; - } + const std::size_t current_core = system.CurrentCoreIndex(); + auto& monitor = system.Monitor(); + u32 current_value; + do { + monitor.SetExclusive(current_core, address); + current_value = memory.Read32(address); + + if (current_value != value) { + return ERR_INVALID_STATE; + } + current_value++; + } while (!monitor.ExclusiveWrite32(current_core, address, current_value)); - memory.Write32(address, static_cast(value + 1)); return SignalToAddressOnly(address, num_to_wake); } ResultCode AddressArbiter::ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake) { + SchedulerLock lock(system.Kernel()); auto& memory = system.Memory(); // Ensure that we can write to the address. @@ -92,29 +116,34 @@ ResultCode AddressArbiter::ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr a const std::vector> waiting_threads = GetThreadsWaitingOnAddress(address); - // Determine the modified value depending on the waiting count. + const std::size_t current_core = system.CurrentCoreIndex(); + auto& monitor = system.Monitor(); s32 updated_value; - if (num_to_wake <= 0) { - if (waiting_threads.empty()) { - updated_value = value + 1; - } else { - updated_value = value - 1; - } - } else { - if (waiting_threads.empty()) { - updated_value = value + 1; - } else if (waiting_threads.size() <= static_cast(num_to_wake)) { - updated_value = value - 1; - } else { - updated_value = value; - } - } + do { + monitor.SetExclusive(current_core, address); + updated_value = memory.Read32(address); - if (static_cast(memory.Read32(address)) != value) { - return ERR_INVALID_STATE; - } + if (updated_value != value) { + return ERR_INVALID_STATE; + } + // Determine the modified value depending on the waiting count. + if (num_to_wake <= 0) { + if (waiting_threads.empty()) { + updated_value = value + 1; + } else { + updated_value = value - 1; + } + } else { + if (waiting_threads.empty()) { + updated_value = value + 1; + } else if (waiting_threads.size() <= static_cast(num_to_wake)) { + updated_value = value - 1; + } else { + updated_value = value; + } + } + } while (!monitor.ExclusiveWrite32(current_core, address, updated_value)); - memory.Write32(address, static_cast(updated_value)); WakeThreads(waiting_threads, num_to_wake); return RESULT_SUCCESS; } @@ -136,60 +165,121 @@ ResultCode AddressArbiter::WaitForAddress(VAddr address, ArbitrationType type, s ResultCode AddressArbiter::WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout, bool should_decrement) { auto& memory = system.Memory(); + auto& kernel = system.Kernel(); + Thread* current_thread = system.CurrentScheduler().GetCurrentThread(); - // Ensure that we can read the address. - if (!memory.IsValidVirtualAddress(address)) { - return ERR_INVALID_ADDRESS_STATE; + Handle event_handle = InvalidHandle; + { + SchedulerLockAndSleep lock(kernel, event_handle, current_thread, timeout); + + // Ensure that we can read the address. + if (!memory.IsValidVirtualAddress(address)) { + lock.CancelSleep(); + return ERR_INVALID_ADDRESS_STATE; + } + + /// TODO(Blinkhawk): Check termination pending. + + s32 current_value = static_cast(memory.Read32(address)); + if (current_value >= value) { + lock.CancelSleep(); + return ERR_INVALID_STATE; + } + + s32 decrement_value; + + const std::size_t current_core = system.CurrentCoreIndex(); + auto& monitor = system.Monitor(); + do { + monitor.SetExclusive(current_core, address); + current_value = static_cast(memory.Read32(address)); + if (should_decrement) { + decrement_value = current_value - 1; + } else { + decrement_value = current_value; + } + } while ( + !monitor.ExclusiveWrite32(current_core, address, static_cast(decrement_value))); + + // Short-circuit without rescheduling, if timeout is zero. + if (timeout == 0) { + lock.CancelSleep(); + return RESULT_TIMEOUT; + } + + current_thread->SetSynchronizationResults(nullptr, RESULT_TIMEOUT); + current_thread->SetArbiterWaitAddress(address); + InsertThread(SharedFrom(current_thread)); + current_thread->SetStatus(ThreadStatus::WaitArb); + current_thread->WaitForArbitration(true); } - const s32 cur_value = static_cast(memory.Read32(address)); - if (cur_value >= value) { - return ERR_INVALID_STATE; + if (event_handle != InvalidHandle) { + auto& time_manager = kernel.TimeManager(); + time_manager.UnscheduleTimeEvent(event_handle); } - if (should_decrement) { - memory.Write32(address, static_cast(cur_value - 1)); + { + SchedulerLock lock(kernel); + if (current_thread->IsWaitingForArbitration()) { + RemoveThread(SharedFrom(current_thread)); + current_thread->WaitForArbitration(false); + } } - // Short-circuit without rescheduling, if timeout is zero. - if (timeout == 0) { - return RESULT_TIMEOUT; - } - - return WaitForAddressImpl(address, timeout); + return current_thread->GetSignalingResult(); } ResultCode AddressArbiter::WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout) { auto& memory = system.Memory(); - - // Ensure that we can read the address. - if (!memory.IsValidVirtualAddress(address)) { - return ERR_INVALID_ADDRESS_STATE; - } - - // Only wait for the address if equal. - if (static_cast(memory.Read32(address)) != value) { - return ERR_INVALID_STATE; - } - - // Short-circuit without rescheduling if timeout is zero. - if (timeout == 0) { - return RESULT_TIMEOUT; - } - - return WaitForAddressImpl(address, timeout); -} - -ResultCode AddressArbiter::WaitForAddressImpl(VAddr address, s64 timeout) { + auto& kernel = system.Kernel(); Thread* current_thread = system.CurrentScheduler().GetCurrentThread(); - current_thread->SetArbiterWaitAddress(address); - InsertThread(SharedFrom(current_thread)); - current_thread->SetStatus(ThreadStatus::WaitArb); - current_thread->InvalidateWakeupCallback(); - current_thread->WakeAfterDelay(timeout); - system.PrepareReschedule(current_thread->GetProcessorID()); - return RESULT_TIMEOUT; + Handle event_handle = InvalidHandle; + { + SchedulerLockAndSleep lock(kernel, event_handle, current_thread, timeout); + + // Ensure that we can read the address. + if (!memory.IsValidVirtualAddress(address)) { + lock.CancelSleep(); + return ERR_INVALID_ADDRESS_STATE; + } + + /// TODO(Blinkhawk): Check termination pending. + + s32 current_value = static_cast(memory.Read32(address)); + if (current_value != value) { + lock.CancelSleep(); + return ERR_INVALID_STATE; + } + + // Short-circuit without rescheduling, if timeout is zero. + if (timeout == 0) { + lock.CancelSleep(); + return RESULT_TIMEOUT; + } + + current_thread->SetSynchronizationResults(nullptr, RESULT_TIMEOUT); + current_thread->SetArbiterWaitAddress(address); + InsertThread(SharedFrom(current_thread)); + current_thread->SetStatus(ThreadStatus::WaitArb); + current_thread->WaitForArbitration(true); + } + + if (event_handle != InvalidHandle) { + auto& time_manager = kernel.TimeManager(); + time_manager.UnscheduleTimeEvent(event_handle); + } + + { + SchedulerLock lock(kernel); + if (current_thread->IsWaitingForArbitration()) { + RemoveThread(SharedFrom(current_thread)); + current_thread->WaitForArbitration(false); + } + } + + return current_thread->GetSignalingResult(); } void AddressArbiter::HandleWakeupThread(std::shared_ptr thread) { @@ -221,9 +311,9 @@ void AddressArbiter::RemoveThread(std::shared_ptr thread) { const auto iter = std::find_if(thread_list.cbegin(), thread_list.cend(), [&thread](const auto& entry) { return thread == entry; }); - ASSERT(iter != thread_list.cend()); - - thread_list.erase(iter); + if (iter != thread_list.cend()) { + thread_list.erase(iter); + } } std::vector> AddressArbiter::GetThreadsWaitingOnAddress( diff --git a/src/core/hle/kernel/address_arbiter.h b/src/core/hle/kernel/address_arbiter.h index f958eee5a..0b05d533c 100644 --- a/src/core/hle/kernel/address_arbiter.h +++ b/src/core/hle/kernel/address_arbiter.h @@ -73,9 +73,6 @@ private: /// Waits on an address if the value passed is equal to the argument value. ResultCode WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout); - // Waits on the given address with a timeout in nanoseconds - ResultCode WaitForAddressImpl(VAddr address, s64 timeout); - /// Wake up num_to_wake (or all) threads in a vector. void WakeThreads(const std::vector>& waiting_threads, s32 num_to_wake); diff --git a/src/core/hle/kernel/svc.cpp b/src/core/hle/kernel/svc.cpp index 5e9dd43bf..718462b2b 100644 --- a/src/core/hle/kernel/svc.cpp +++ b/src/core/hle/kernel/svc.cpp @@ -1691,7 +1691,6 @@ static ResultCode WaitForAddress(Core::System& system, VAddr address, u32 type, LOG_TRACE(Kernel_SVC, "called, address=0x{:X}, type=0x{:X}, value=0x{:X}, timeout={}", address, type, value, timeout); - UNIMPLEMENTED(); // If the passed address is a kernel virtual address, return invalid memory state. if (Core::Memory::IsKernelVirtualAddress(address)) { LOG_ERROR(Kernel_SVC, "Address is a kernel virtual address, address={:016X}", address); @@ -1717,8 +1716,6 @@ static ResultCode SignalToAddress(Core::System& system, VAddr address, u32 type, LOG_TRACE(Kernel_SVC, "called, address=0x{:X}, type=0x{:X}, value=0x{:X}, num_to_wake=0x{:X}", address, type, value, num_to_wake); - UNIMPLEMENTED(); - // If the passed address is a kernel virtual address, return invalid memory state. if (Core::Memory::IsKernelVirtualAddress(address)) { LOG_ERROR(Kernel_SVC, "Address is a kernel virtual address, address={:016X}", address); diff --git a/src/core/hle/kernel/thread.h b/src/core/hle/kernel/thread.h index 7b6d1b4ec..e8355bbd1 100644 --- a/src/core/hle/kernel/thread.h +++ b/src/core/hle/kernel/thread.h @@ -548,6 +548,14 @@ public: return global_handle; } + bool IsWaitingForArbitration() const { + return waiting_for_arbitration; + } + + void WaitForArbitration(bool set) { + waiting_for_arbitration = set; + } + private: friend class GlobalScheduler; friend class Scheduler; @@ -615,6 +623,7 @@ private: /// If waiting for an AddressArbiter, this is the address being waited on. VAddr arb_wait_address{0}; + bool waiting_for_arbitration{}; /// Handle used as userdata to reference this object when inserting into the CoreTiming queue. Handle global_handle = 0;