diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 26c9a0a28..566b5f36b 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -95,6 +95,10 @@ struct CommandEncoder { return enc_->setBytes(&v, sizeof(T), idx); } + void set_threadgroup_memory_length(size_t length, NS::UInteger index) { + enc_->setThreadgroupMemoryLength(length, index); + } + ConcurrentContext start_concurrent() { return ConcurrentContext(*this); }