commit 2354864dd64f1dab5e998f0e112cef7a619bc228 Author: louyu Date: Mon Jun 12 17:41:12 2023 +0800 first commit diff --git a/README.md b/README.md new file mode 100644 index 0000000..7a997b1 --- /dev/null +++ b/README.md @@ -0,0 +1,34 @@ +## 基于C++11 实现线程池 + +**项目描述:** + +1. 基于可变参模板编程和引用折叠原理,实现线程池submitTask接口,支持任意任务函数和任意参数的传递; +2. 使用future类型定制submitTask提交任务的返回值; +3. 使用map和queue容器管理线程对象和任务; +4. 基于条件变量condition_variable和互斥锁mutex实现任务提交线程和任务执行线程间的通信机制; +5. 支持fixed和cached模式的线程池定制。 + +**使用示例:** + +```c++ +#include +#include +#include "ThreadPool.h" // 引入头文件 + +using namespace std; + +int sum1(int a, int b) { + return a + b; +} + +int main() { + ThreadPool pool; // 定义线程池对象 + pool.start(); // 启动线程池 + future res = pool.submitTask(sum1, 10, 20); // 提交异步任务 + + cout << res.get() << endl; // 打印结果 + + return 0; +} +``` +更多进阶用法详见头文件中的注释说明。 \ No newline at end of file diff --git a/ThreadPool.h b/ThreadPool.h new file mode 100644 index 0000000..6274133 --- /dev/null +++ b/ThreadPool.h @@ -0,0 +1,272 @@ +// +// Created by louyu on 2023/6/12. +// + +#ifndef THREADPOOL_FINAL_THREADPOOL_H +#define THREADPOOL_FINAL_THREADPOOL_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +const int TASK_MAX_THRESHOLD = 1024; +const int THREAD_MAX_THRESHOLD = 10; +const int THREAD_MAX_IDLE_TIME = 60; // 单位秒 + +// 线程池支持的模式 +enum class PoolMode { + MODE_FIXED, // 固定数量线程 + MODE_CACHED, // 线程数量可动态增长 +}; + +class Thread { +public: + using ThreadFunc = std::function; + Thread(ThreadFunc func): func_(func), threadId_(generateId_ ++) { + + } + + ~Thread() = default; + + void start() { // 启动线程 + std::thread t(func_, threadId_); // 创建一个线程执行一个线程函数 + t.detach(); // 设置分离线程,否则出作用域后线程对象销毁,线程函数也会中止 + } + + unsigned getId() const { + return threadId_; + } +private: + ThreadFunc func_; + static unsigned generateId_; + unsigned threadId_; // 线程id,会在线程池对象中建立id与线程对象的映射关系 +}; + +unsigned Thread::generateId_ = 0; + +// 线程池类型 +class ThreadPool { +public: + ThreadPool() + : initThreadSize_(0) + , taskSize_(0) + , idleThreadSize_(0) + , curThreadSize_(0) + , taskQueMaxThreshold_(TASK_MAX_THRESHOLD) + , threadSizeThreshold_(THREAD_MAX_THRESHOLD) + , poolMode_(PoolMode::MODE_FIXED) + , isPoolRunning_(false) { + + } + ~ThreadPool() { + isPoolRunning_ = false; + + // 等待线程池中所有的线程返回 + std::unique_lock lock(taskQueMtx_); + notEmpty_.notify_all(); // 线程池对象析构前,唤醒所有被阻塞的线程以返回结果 + exitCond_.wait(lock, [&]() {return threads_.empty();}); + } + + void setMode(PoolMode mode) { // 设置线程池的工作模式 + if(checkRunningState()) { // 若线程池已经在运行,禁止修改线程池模式 + return; + } + poolMode_ = mode; + } + + void setTaskQueMaxThreshold(unsigned threshold) { // 设置task任务队列任务上限 + if(checkRunningState()) { + return; + } + if(poolMode_ == PoolMode::MODE_CACHED) { + taskQueMaxThreshold_ = threshold; + } + } + + void setThreadSizeThreshold(unsigned threshold) { // 设置线程池cached模式下线程阈值 + if(checkRunningState()) { + return; + } + threadSizeThreshold_ = threshold; + } + + // 使用可变参模版编程,使得submitTask可以接收任意任务函数与任意数量的参数 + // pool.submitTask(sum, 10, 20); + // decltype可以根据括号内的表达式推导类型 + template + auto submitTask(Func &&func, Args &&...args) -> std::future { + // 打包任务,放入任务队列 + using RType = decltype(func(args...)); + auto task = std::make_shared>( + std::bind(std::forward(func), std::forward(args)...)); + std::future result = task->get_future(); + + // 获取锁,任务队列不是线程安全的 + std::unique_lock lock(taskQueMtx_); // unique_lock构造同时会获取锁 + + // 用条件变量等待任务队列有空余,wait一直等待直到后续条件满足 wait_for等待指定时间段 wait_until等待到某时间节点 + // 用户提交任务,最长阻塞不能超过1s,否则判断提交任务失败,返回 + if(!notFull_.wait_for(lock, std::chrono::seconds(1), [&]() { return taskQue_.size() < taskQueMaxThreshold_; })) {// 不满足lambda表达式条件时wait,释放锁 + // 若等待1s条件依然不满足,提交任务失败 + std::cerr << "task queue is full, submit task fail." << std::endl; + auto emptyTask = std::make_shared>( // 任务提交失败,运行一个空任务 + []() -> RType {return RType();}); + (*emptyTask)(); // 执行空任务,否则主线程调用future.get()会死锁 + return emptyTask->get_future(); + } + + // 若有空余,将任务放入任务队列中 + // using Task = std::function; + // 如何将一个带返回值的线程函数封装到std::function中,用lambda表达式 + taskQue_.emplace([task]() { + (*task)(); + }); + taskSize_ ++; + + // 新放入任务,任务队列必然不空,notEmpty_通知 + notEmpty_.notify_all(); + + // cached模式任务处理较为紧急,适合小且快的任务多场景,需要根据任务数量和空闲线程的数量,决定是否动态扩充线程数 + if(poolMode_ == PoolMode::MODE_CACHED + && taskSize_ > idleThreadSize_ + && curThreadSize_ < threadSizeThreshold_) { + // 创建新线程 + auto ptr = std::make_unique(std::bind(&ThreadPool::threadFunc, this, std::placeholders::_1)); + unsigned threadId = ptr->getId(); + threads_.emplace(threadId, std::move(ptr)); // unique_ptr拷贝构造被删除,要强转右值 + threads_[threadId]->start(); // 启动线程 + + // 修改线程个数相关变量 + curThreadSize_ ++; + idleThreadSize_ ++; + } + + // 返回Result对象 + return result; + } + + void start(unsigned initThreadSize = std::thread::hardware_concurrency()) { // 开启线程池,默认开启cpu核心数个线程 + isPoolRunning_ = true; // 设置线程池启动状态 + + initThreadSize_ = initThreadSize; // 初始化初始线程个数 + curThreadSize_ = initThreadSize; + + // 创建线程对象 + for(int i = 0; i < initThreadSize_; i ++) { + auto ptr = std::make_unique(std::bind(&ThreadPool::threadFunc, this, std::placeholders::_1)); + unsigned threadId = ptr->getId(); + threads_.emplace(threadId, std::move(ptr)); // unique_ptr拷贝构造被删除,要强转右值 + } + + // 启动所有线程 + for(int i = 0; i < initThreadSize_; i ++) { + threads_[i]->start(); + idleThreadSize_ ++; // 记录初始线程的数量 + } + } + + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator=(const ThreadPool&) = delete; +private: + void threadFunc(unsigned threadId) { // 定义线程函数 + auto lastTime = std::chrono::high_resolution_clock().now(); // 记录该线程调度时间的变量 + + for(;;) { + Task task; + { + // 先获取锁 + std::unique_lock lock(taskQueMtx_); + + // 当任务队列中有任务的时候,不论线程池是否要析构,先要把任务做完 + while(taskQue_.empty()) { // 任务队列为空时 + + if(!isPoolRunning_) { // 唤醒后如果线程池要析构了,那么停止线程执行 + threads_.erase(threadId); // 线程结束前把线程对象从线程列表容器中删除 + exitCond_.notify_all(); // 通知线程池的析构函数,有线程析构 + return; + } + + // 否则,需要等待任务到来 + + // cached模式下,有可能需要回收之前创建的线程,即超过initThreadSize_数量的线程要进行回收 + // 当前时间 - 上一次线程执行时间 > 60s + if(poolMode_ == PoolMode::MODE_CACHED) { + if(std::cv_status::timeout == notEmpty_.wait_for(lock, std::chrono::seconds(1))) { // 条件变量超时返回 + auto now = std::chrono::high_resolution_clock().now(); + auto dur = std::chrono::duration_cast(now - lastTime); + if(dur.count() >= THREAD_MAX_IDLE_TIME && curThreadSize_ > initThreadSize_) { + // 开始回收线程 + threads_.erase(threadId); // 把线程对象从线程列表容器中删除 + + // 记录线程数量的相关变量的值修改 + curThreadSize_ --; + idleThreadSize_ --; + + return; + } + } + } else { // fixed模式 + // 等待notEmpty_条件 + notEmpty_.wait(lock); + } + } + + idleThreadSize_ --; // 空闲线程-- + + // 从任务队列取一个任务 + task = taskQue_.front(); + taskQue_.pop(); + taskSize_ --; + + // 如果依然有剩余任务,继续通知其它线程执行任务 + if(!taskQue_.empty()) { + notEmpty_.notify_all(); + } + + // 取出任务后必然不满,通知可以继续提交任务 + notFull_.notify_all(); + } // 释放锁,不能执行任务的时候还占着锁 + + // 当前线程负责执行这个任务 + if(task != nullptr) { + task(); // 执行任务function + } + idleThreadSize_ ++; // 任务执行完,空闲线程数量增加 + lastTime = std::chrono::high_resolution_clock().now(); // 更新该线程被调度执行完的时间 + } + } + + bool checkRunningState() const { + return isPoolRunning_; + } +private: + std::unordered_map> threads_; // 线程列表 + unsigned initThreadSize_; // 初始线程数量 + std::atomic curThreadSize_; // 记录当前线程池中线程总数量 + std::atomic idleThreadSize_; // 记录空闲线程的数量 + unsigned threadSizeThreshold_; // 线程数量上限阈值 + + // 这里队列里不能存裸指针,避免用户传入一个临时对象,使用智能指针延长外部传进来对象的生命周期 + using Task = std::function; + std::queue taskQue_; // 任务队列 + std::atomic taskSize_; // 任务队列任务数,用原子变量保证原子性 + unsigned taskQueMaxThreshold_; // 任务队列任务数上限 + + std::mutex taskQueMtx_; // 保证任务队列的线程安全 + std::condition_variable notFull_; // 任务队列不满 + std::condition_variable notEmpty_; // 任务队列不空 + std::condition_variable exitCond_; // 等待线程资源全部回收 + + PoolMode poolMode_; // 当前线程池的工作模式 + std::atomic isPoolRunning_; // 表示当前线程池的启动状态 +}; + +#endif //THREADPOOL_FINAL_THREADPOOL_H diff --git a/main.cpp b/main.cpp new file mode 100644 index 0000000..868267a --- /dev/null +++ b/main.cpp @@ -0,0 +1,19 @@ +#include +#include +#include "ThreadPool.h" + +using namespace std; + +int sum1(int a, int b) { + return a + b; +} + +int main() { + ThreadPool pool; + pool.start(); + future res = pool.submitTask(sum1, 10, 20); + + cout << res.get() << endl; + + return 0; +}