diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu
index 5aa287603..c71795fad 100644
--- a/mlx/backend/cuda/layer_norm.cu
+++ b/mlx/backend/cuda/layer_norm.cu
@@ -244,8 +244,7 @@ void LayerNorm::eval_gpu(
     }
   };
 
-  array o = set_output(inputs[0]);
-  const array& x = o.data_shared_ptr() ? o : out;
+  const array x = set_output(inputs[0]);
   const array& w = inputs[1];
   const array& b = inputs[2];
 
diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu
index 73bf216f0..3c521b90d 100644
--- a/mlx/backend/cuda/rms_norm.cu
+++ b/mlx/backend/cuda/rms_norm.cu
@@ -156,7 +156,6 @@ __global__ void rms_norm_vjp(
     cub::LoadDirectBlocked(index, g, gn, axis_size);
     cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
     for (int i = 0; i < N_READS; i++) {
-      // thread_g * thread_w * norm - thread_x * meangwx * norm3
       float xi = xn[i];
       float wi = wn[i];
       float gi = gn[i];
@@ -214,8 +213,7 @@ void RMSNorm::eval_gpu(
     }
   };
 
-  array o = set_output(inputs[0]);
-  const array& x = o.data_shared_ptr() ? o : out;
+  const array x = set_output(inputs[0]);
   const array& w = inputs[1];
 
   int32_t axis_size = x.shape().back();