Check for valid launch parameters

This commit is contained in:
Angelos Katharopoulos
2025-06-06 11:41:56 -07:00
parent 570dd8287a
commit 97bd67c032

View File

@@ -273,7 +273,13 @@ void LayerNorm::eval_gpu(
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) {
std::ostringstream msg;
msg << "[layer_norm] Threadgroup size " << threadgroup_size
<< " is larger than the maximum allowed threadgroup size "
<< kernel->maxTotalThreadsPerThreadgroup();
throw std::runtime_error(msg.str());
}
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
@@ -396,7 +402,13 @@ void LayerNormVJP::eval_gpu(
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) {
std::ostringstream msg;
msg << "[vjp_layer_norm] Threadgroup size " << threadgroup_size
<< " is larger than the maximum allowed threadgroup size "
<< kernel->maxTotalThreadsPerThreadgroup();
throw std::runtime_error(msg.str());
}
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);