Merge pull request #4849 from ReinUsesLisp/fix-fiber-test
tests: Fix data race in fibers test
This commit is contained in:
		| @@ -6,18 +6,40 @@ | ||||
| #include <cstdlib> | ||||
| #include <functional> | ||||
| #include <memory> | ||||
| #include <mutex> | ||||
| #include <stdexcept> | ||||
| #include <thread> | ||||
| #include <unordered_map> | ||||
| #include <vector> | ||||
|  | ||||
| #include <catch2/catch.hpp> | ||||
| #include <math.h> | ||||
|  | ||||
| #include "common/common_types.h" | ||||
| #include "common/fiber.h" | ||||
| #include "common/spin_lock.h" | ||||
|  | ||||
| namespace Common { | ||||
|  | ||||
| class ThreadIds { | ||||
| public: | ||||
|     void Register(u32 id) { | ||||
|         const auto thread_id = std::this_thread::get_id(); | ||||
|         std::scoped_lock lock{mutex}; | ||||
|         if (ids.contains(thread_id)) { | ||||
|             throw std::logic_error{"Registering the same thread twice"}; | ||||
|         } | ||||
|         ids.emplace(thread_id, id); | ||||
|     } | ||||
|  | ||||
|     [[nodiscard]] u32 Get() const { | ||||
|         std::scoped_lock lock{mutex}; | ||||
|         return ids.at(std::this_thread::get_id()); | ||||
|     } | ||||
|  | ||||
| private: | ||||
|     mutable std::mutex mutex; | ||||
|     std::unordered_map<std::thread::id, u32> ids; | ||||
| }; | ||||
|  | ||||
| class TestControl1 { | ||||
| public: | ||||
|     TestControl1() = default; | ||||
| @@ -26,7 +48,7 @@ public: | ||||
|  | ||||
|     void ExecuteThread(u32 id); | ||||
|  | ||||
|     std::unordered_map<std::thread::id, u32> ids; | ||||
|     ThreadIds thread_ids; | ||||
|     std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; | ||||
|     std::vector<std::shared_ptr<Common::Fiber>> work_fibers; | ||||
|     std::vector<u32> items; | ||||
| @@ -39,8 +61,7 @@ static void WorkControl1(void* control) { | ||||
| } | ||||
|  | ||||
| void TestControl1::DoWork() { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     u32 id = ids[this_id]; | ||||
|     const u32 id = thread_ids.Get(); | ||||
|     u32 value = items[id]; | ||||
|     for (u32 i = 0; i < id; i++) { | ||||
|         value++; | ||||
| @@ -50,8 +71,7 @@ void TestControl1::DoWork() { | ||||
| } | ||||
|  | ||||
| void TestControl1::ExecuteThread(u32 id) { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     ids[this_id] = id; | ||||
|     thread_ids.Register(id); | ||||
|     auto thread_fiber = Fiber::ThreadToFiber(); | ||||
|     thread_fibers[id] = thread_fiber; | ||||
|     work_fibers[id] = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl1}, this); | ||||
| @@ -98,8 +118,7 @@ public: | ||||
|             value1 += i; | ||||
|         } | ||||
|         Fiber::YieldTo(fiber1, fiber3); | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         const u32 id = thread_ids.Get(); | ||||
|         assert1 = id == 1; | ||||
|         value2 += 5000; | ||||
|         Fiber::YieldTo(fiber1, thread_fibers[id]); | ||||
| @@ -115,8 +134,7 @@ public: | ||||
|     } | ||||
|  | ||||
|     void DoWork3() { | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         const u32 id = thread_ids.Get(); | ||||
|         assert2 = id == 0; | ||||
|         value1 += 1000; | ||||
|         Fiber::YieldTo(fiber3, thread_fibers[id]); | ||||
| @@ -125,14 +143,12 @@ public: | ||||
|     void ExecuteThread(u32 id); | ||||
|  | ||||
|     void CallFiber1() { | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         const u32 id = thread_ids.Get(); | ||||
|         Fiber::YieldTo(thread_fibers[id], fiber1); | ||||
|     } | ||||
|  | ||||
|     void CallFiber2() { | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         const u32 id = thread_ids.Get(); | ||||
|         Fiber::YieldTo(thread_fibers[id], fiber2); | ||||
|     } | ||||
|  | ||||
| @@ -145,7 +161,7 @@ public: | ||||
|     u32 value2{}; | ||||
|     std::atomic<bool> trap{true}; | ||||
|     std::atomic<bool> trap2{true}; | ||||
|     std::unordered_map<std::thread::id, u32> ids; | ||||
|     ThreadIds thread_ids; | ||||
|     std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; | ||||
|     std::shared_ptr<Common::Fiber> fiber1; | ||||
|     std::shared_ptr<Common::Fiber> fiber2; | ||||
| @@ -168,15 +184,13 @@ static void WorkControl2_3(void* control) { | ||||
| } | ||||
|  | ||||
| void TestControl2::ExecuteThread(u32 id) { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     ids[this_id] = id; | ||||
|     thread_ids.Register(id); | ||||
|     auto thread_fiber = Fiber::ThreadToFiber(); | ||||
|     thread_fibers[id] = thread_fiber; | ||||
| } | ||||
|  | ||||
| void TestControl2::Exit() { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     u32 id = ids[this_id]; | ||||
|     const u32 id = thread_ids.Get(); | ||||
|     thread_fibers[id]->Exit(); | ||||
| } | ||||
|  | ||||
| @@ -228,24 +242,21 @@ public: | ||||
|     void DoWork1() { | ||||
|         value1 += 1; | ||||
|         Fiber::YieldTo(fiber1, fiber2); | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         const u32 id = thread_ids.Get(); | ||||
|         value3 += 1; | ||||
|         Fiber::YieldTo(fiber1, thread_fibers[id]); | ||||
|     } | ||||
|  | ||||
|     void DoWork2() { | ||||
|         value2 += 1; | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         const u32 id = thread_ids.Get(); | ||||
|         Fiber::YieldTo(fiber2, thread_fibers[id]); | ||||
|     } | ||||
|  | ||||
|     void ExecuteThread(u32 id); | ||||
|  | ||||
|     void CallFiber1() { | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         const u32 id = thread_ids.Get(); | ||||
|         Fiber::YieldTo(thread_fibers[id], fiber1); | ||||
|     } | ||||
|  | ||||
| @@ -254,7 +265,7 @@ public: | ||||
|     u32 value1{}; | ||||
|     u32 value2{}; | ||||
|     u32 value3{}; | ||||
|     std::unordered_map<std::thread::id, u32> ids; | ||||
|     ThreadIds thread_ids; | ||||
|     std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; | ||||
|     std::shared_ptr<Common::Fiber> fiber1; | ||||
|     std::shared_ptr<Common::Fiber> fiber2; | ||||
| @@ -271,15 +282,13 @@ static void WorkControl3_2(void* control) { | ||||
| } | ||||
|  | ||||
| void TestControl3::ExecuteThread(u32 id) { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     ids[this_id] = id; | ||||
|     thread_ids.Register(id); | ||||
|     auto thread_fiber = Fiber::ThreadToFiber(); | ||||
|     thread_fibers[id] = thread_fiber; | ||||
| } | ||||
|  | ||||
| void TestControl3::Exit() { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     u32 id = ids[this_id]; | ||||
|     const u32 id = thread_ids.Get(); | ||||
|     thread_fibers[id]->Exit(); | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user