diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index 6a1538905..c53289828 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -273,7 +273,13 @@ void LayerNorm::eval_gpu( size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; - assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); + if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) { + std::ostringstream msg; + msg << "[layer_norm] Threadgroup size " << threadgroup_size + << " is larger than the maximum allowed threadgroup size " + << kernel->maxTotalThreadsPerThreadgroup(); + throw std::runtime_error(msg.str()); + } size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); @@ -396,7 +402,13 @@ void LayerNormVJP::eval_gpu( size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; - assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); + if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) { + std::ostringstream msg; + msg << "[vjp_layer_norm] Threadgroup size " << threadgroup_size + << " is larger than the maximum allowed threadgroup size " + << kernel->maxTotalThreadsPerThreadgroup(); + throw std::runtime_error(msg.str()); + } size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1);