diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index 3b4c0a30a..165868401 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -28,7 +28,7 @@ struct GEMVKernel { static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE"); // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up - // into blocks of (BM * TM, BN * TN) divided amoung threadgroups + // into blocks of (BM * TM, BN * TN) divided among threadgroups // - Every thread works on a block of (TM, TN) // - We assume each thead group is launched with (BN, BM, 1) threads // @@ -166,7 +166,7 @@ template < struct GEMVTKernel { // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up - // into blocks of (BM * TM, BN * TN) divided amoung threadgroups + // into blocks of (BM * TM, BN * TN) divided among threadgroups // - Every thread works on a block of (TM, TN) // - We assume each thead group is launched with (BN, BM, 1) threads //