diff --git a/io_uring/io-wq.c b/io_uring/io-wq.c index 2da0b1ba6a56..62f345587df5 100644 --- a/io_uring/io-wq.c +++ b/io_uring/io-wq.c @@ -1306,13 +1306,16 @@ static int io_wq_cpu_offline(unsigned int cpu, struct hlist_node *node) return __io_wq_cpu_online(wq, cpu, false); } -int io_wq_cpu_affinity(struct io_wq *wq, cpumask_var_t mask) +int io_wq_cpu_affinity(struct io_uring_task *tctx, cpumask_var_t mask) { + if (!tctx || !tctx->io_wq) + return -EINVAL; + rcu_read_lock(); if (mask) - cpumask_copy(wq->cpu_mask, mask); + cpumask_copy(tctx->io_wq->cpu_mask, mask); else - cpumask_copy(wq->cpu_mask, cpu_possible_mask); + cpumask_copy(tctx->io_wq->cpu_mask, cpu_possible_mask); rcu_read_unlock(); return 0; diff --git a/io_uring/io-wq.h b/io_uring/io-wq.h index 31228426d192..06d9ca90c577 100644 --- a/io_uring/io-wq.h +++ b/io_uring/io-wq.h @@ -50,7 +50,7 @@ void io_wq_put_and_exit(struct io_wq *wq); void io_wq_enqueue(struct io_wq *wq, struct io_wq_work *work); void io_wq_hash_work(struct io_wq_work *work, void *val); -int io_wq_cpu_affinity(struct io_wq *wq, cpumask_var_t mask); +int io_wq_cpu_affinity(struct io_uring_task *tctx, cpumask_var_t mask); int io_wq_max_workers(struct io_wq *wq, int *new_count); static inline bool io_wq_is_hashed(struct io_wq_work *work) diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c index e189158ebbdd..e1a23f4993d3 100644 --- a/io_uring/io_uring.c +++ b/io_uring/io_uring.c @@ -4183,16 +4183,28 @@ static int io_register_enable_rings(struct io_ring_ctx *ctx) return 0; } +static __cold int __io_register_iowq_aff(struct io_ring_ctx *ctx, + cpumask_var_t new_mask) +{ + int ret; + + if (!(ctx->flags & IORING_SETUP_SQPOLL)) { + ret = io_wq_cpu_affinity(current->io_uring, new_mask); + } else { + mutex_unlock(&ctx->uring_lock); + ret = io_sqpoll_wq_cpu_affinity(ctx, new_mask); + mutex_lock(&ctx->uring_lock); + } + + return ret; +} + static __cold int io_register_iowq_aff(struct io_ring_ctx *ctx, void __user *arg, unsigned len) { - struct io_uring_task *tctx = current->io_uring; cpumask_var_t new_mask; int ret; - if (!tctx || !tctx->io_wq) - return -EINVAL; - if (!alloc_cpumask_var(&new_mask, GFP_KERNEL)) return -ENOMEM; @@ -4213,19 +4225,14 @@ static __cold int io_register_iowq_aff(struct io_ring_ctx *ctx, return -EFAULT; } - ret = io_wq_cpu_affinity(tctx->io_wq, new_mask); + ret = __io_register_iowq_aff(ctx, new_mask); free_cpumask_var(new_mask); return ret; } static __cold int io_unregister_iowq_aff(struct io_ring_ctx *ctx) { - struct io_uring_task *tctx = current->io_uring; - - if (!tctx || !tctx->io_wq) - return -EINVAL; - - return io_wq_cpu_affinity(tctx->io_wq, NULL); + return __io_register_iowq_aff(ctx, NULL); } static __cold int io_register_iowq_max_workers(struct io_ring_ctx *ctx, diff --git a/io_uring/sqpoll.c b/io_uring/sqpoll.c index 5e329e3cd470..ee2d2c687fda 100644 --- a/io_uring/sqpoll.c +++ b/io_uring/sqpoll.c @@ -421,3 +421,18 @@ __cold int io_sq_offload_create(struct io_ring_ctx *ctx, io_sq_thread_finish(ctx); return ret; } + +__cold int io_sqpoll_wq_cpu_affinity(struct io_ring_ctx *ctx, + cpumask_var_t mask) +{ + struct io_sq_data *sqd = ctx->sq_data; + int ret = -EINVAL; + + if (sqd) { + io_sq_thread_park(sqd); + ret = io_wq_cpu_affinity(sqd->thread->io_uring, mask); + io_sq_thread_unpark(sqd); + } + + return ret; +} diff --git a/io_uring/sqpoll.h b/io_uring/sqpoll.h index e1b8d508d22d..8df37e8c9149 100644 --- a/io_uring/sqpoll.h +++ b/io_uring/sqpoll.h @@ -27,3 +27,4 @@ void io_sq_thread_park(struct io_sq_data *sqd); void io_sq_thread_unpark(struct io_sq_data *sqd); void io_put_sq_data(struct io_sq_data *sqd); void io_sqpoll_wait_sq(struct io_ring_ctx *ctx); +int io_sqpoll_wq_cpu_affinity(struct io_ring_ctx *ctx, cpumask_var_t mask);