diff options
Diffstat (limited to 'src/common/thread_worker.h')
-rw-r--r-- | src/common/thread_worker.h | 103 |
1 files changed, 95 insertions, 8 deletions
diff --git a/src/common/thread_worker.h b/src/common/thread_worker.h index f1859971f..cd0017726 100644 --- a/src/common/thread_worker.h +++ b/src/common/thread_worker.h @@ -5,26 +5,113 @@ #pragma once #include <atomic> +#include <condition_variable> #include <functional> #include <mutex> +#include <stop_token> #include <string> +#include <thread> +#include <type_traits> #include <vector> #include <queue> +#include "common/thread.h" +#include "common/unique_function.h" + namespace Common { -class ThreadWorker final { +template <class StateType = void> +class StatefulThreadWorker { + static constexpr bool with_state = !std::is_same_v<StateType, void>; + + struct DummyCallable { + int operator()() const noexcept { + return 0; + } + }; + + using Task = + std::conditional_t<with_state, UniqueFunction<void, StateType*>, UniqueFunction<void>>; + using StateMaker = std::conditional_t<with_state, std::function<StateType()>, DummyCallable>; + public: - explicit ThreadWorker(std::size_t num_workers, const std::string& name); - ~ThreadWorker(); - void QueueWork(std::function<void()>&& work); + explicit StatefulThreadWorker(size_t num_workers, std::string name, StateMaker func = {}) + : workers_queued{num_workers}, thread_name{std::move(name)} { + const auto lambda = [this, func](std::stop_token stop_token) { + Common::SetCurrentThreadName(thread_name.c_str()); + { + [[maybe_unused]] std::conditional_t<with_state, StateType, int> state{func()}; + while (!stop_token.stop_requested()) { + Task task; + { + std::unique_lock lock{queue_mutex}; + if (requests.empty()) { + wait_condition.notify_all(); + } + condition.wait(lock, stop_token, [this] { return !requests.empty(); }); + if (stop_token.stop_requested()) { + break; + } + task = std::move(requests.front()); + requests.pop(); + } + if constexpr (with_state) { + task(&state); + } else { + task(); + } + ++work_done; + } + } + ++workers_stopped; + wait_condition.notify_all(); + }; + threads.reserve(num_workers); + for (size_t i = 0; i < num_workers; ++i) { + threads.emplace_back(lambda); + } + } + + StatefulThreadWorker& operator=(const StatefulThreadWorker&) = delete; + StatefulThreadWorker(const StatefulThreadWorker&) = delete; + + StatefulThreadWorker& operator=(StatefulThreadWorker&&) = delete; + StatefulThreadWorker(StatefulThreadWorker&&) = delete; + + void QueueWork(Task work) { + { + std::unique_lock lock{queue_mutex}; + requests.emplace(std::move(work)); + ++work_scheduled; + } + condition.notify_one(); + } + + void WaitForRequests(std::stop_token stop_token = {}) { + std::stop_callback callback(stop_token, [this] { + for (auto& thread : threads) { + thread.request_stop(); + } + }); + std::unique_lock lock{queue_mutex}; + wait_condition.wait(lock, [this] { + return workers_stopped >= workers_queued || work_done >= work_scheduled; + }); + } private: - std::vector<std::thread> threads; - std::queue<std::function<void()>> requests; + std::queue<Task> requests; std::mutex queue_mutex; - std::condition_variable condition; - std::atomic_bool stop{}; + std::condition_variable_any condition; + std::condition_variable wait_condition; + std::atomic<size_t> work_scheduled{}; + std::atomic<size_t> work_done{}; + std::atomic<size_t> workers_stopped{}; + std::atomic<size_t> workers_queued{}; + std::string thread_name; + std::vector<std::jthread> threads; }; +using ThreadWorker = StatefulThreadWorker<>; + } // namespace Common |