segfaut layer norm grad (#955)

This commit is contained in:
Awni Hannun
2024-04-04 10:59:15 -07:00
committed by GitHub
parent e142aaf8a1
commit d88d2124b5
2 changed files with 12 additions and 1 deletions

View File

@@ -392,7 +392,7 @@ void LayerNormVJP::eval_gpu(
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
uint32_t w_stride = w.strides()[0];
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
set_array_buffer(compute_encoder, w, 1);