mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
segfaut layer norm grad (#955)
This commit is contained in:
@@ -392,7 +392,7 @@ void LayerNormVJP::eval_gpu(
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
uint32_t w_stride = w.strides()[0];
|
||||
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
|
Reference in New Issue
Block a user