diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index 84a0dd04e0..c9b0067c8a 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -28,7 +28,7 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(a[0], b[0]);
+    out_vec[i] = Op{}(a[0], b[0]);
   }
 
   store_vector(out, index, out_vec);
@@ -49,7 +49,7 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
+    out_vec[i] = Op{}(a[0], b_vec[i]);
   }
 
   store_vector(out, index, out_vec);
@@ -70,7 +70,7 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
+    out_vec[i] = Op{}(a_vec[i], b[0]);
   }
 
   store_vector(out, index, out_vec);
@@ -92,7 +92,7 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
+    out_vec[i] = Op{}(a_vec[i], b_vec[i]);
   }
 
   store_vector(out, index, out_vec);
@@ -248,8 +248,7 @@ void binary_op_gpu_inplace(
   } else {
     dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(InType);
       auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
       if (bopt == BinaryOpType::ScalarVector) {
         kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu
index dfcd81347a..598924c098 100644
--- a/mlx/backend/cuda/binary_two.cu
+++ b/mlx/backend/cuda/binary_two.cu
@@ -33,8 +33,8 @@ binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
     auto out = Op{}(a[0], b[0]);
-    out_a_vec.val[i] = out[0];
-    out_b_vec.val[i] = out[1];
+    out_a_vec[i] = out[0];
+    out_b_vec[i] = out[1];
   }
 
   store_vector(out_a, index, out_a_vec);
@@ -60,9 +60,9 @@ binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   AlignedVector<Out, N_READS> out_b_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    auto out = Op{}(a[0], b_vec.val[i]);
-    out_a_vec.val[i] = out[0];
-    out_b_vec.val[i] = out[1];
+    auto out = Op{}(a[0], b_vec[i]);
+    out_a_vec[i] = out[0];
+    out_b_vec[i] = out[1];
   }
 
   store_vector(out_a, index, out_a_vec);
@@ -88,9 +88,9 @@ binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   AlignedVector<Out, N_READS> out_b_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    auto out = Op{}(a_vec.val[i], b[0]);
-    out_a_vec.val[i] = out[0];
-    out_b_vec.val[i] = out[1];
+    auto out = Op{}(a_vec[i], b[0]);
+    out_a_vec[i] = out[0];
+    out_b_vec[i] = out[1];
   }
 
   store_vector(out_a, index, out_a_vec);
@@ -117,9 +117,9 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   AlignedVector<Out, N_READS> out_b_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    auto out = Op{}(a_vec.val[i], b_vec.val[i]);
-    out_a_vec.val[i] = out[0];
-    out_b_vec.val[i] = out[1];
+    auto out = Op{}(a_vec[i], b_vec[i]);
+    out_a_vec[i] = out[0];
+    out_b_vec[i] = out[1];
   }
 
   store_vector(out_a, index, out_a_vec);
@@ -270,8 +270,7 @@ void binary_two_op_gpu_inplace(
   } else {
     dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(InType);
       auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
       if (bopt == BinaryOpType::ScalarVector) {
         kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
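Note on the N_READS change in the two hunks above: 16 / sizeof(InType) sizes every vectorized access at 16 bytes, the widest load/store a CUDA thread issues, so the element count adapts to the dtype (2 for double, 4 for float, 8 for 16-bit types) instead of the fixed 4. A minimal compile-time sketch of that rule, using a hypothetical elems_per_16_bytes helper that is not part of this patch:

template <typename T>
constexpr int elems_per_16_bytes() {
  // Mirrors `constexpr int N_READS = 16 / sizeof(InType);` above.
  static_assert(16 % sizeof(T) == 0, "element size must divide 16 bytes");
  return 16 / sizeof(T);
}

static_assert(elems_per_16_bytes<double>() == 2);
static_assert(elems_per_16_bytes<float>() == 4);
static_assert(elems_per_16_bytes<short>() == 8); // stand-in for a 16-bit dtype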
diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu
index 4e9eaccb7e..8ac0533f37 100644
--- a/mlx/backend/cuda/copy/copy_contiguous.cu
+++ b/mlx/backend/cuda/copy/copy_contiguous.cu
@@ -22,7 +22,7 @@ __global__ void copy_s(const In* in, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = cast_to<Out>(in[0]);
+    out_vec[i] = cast_to<Out>(in[0]);
   }
 
   store_vector(out, index, out_vec);
@@ -43,7 +43,7 @@ __global__ void copy_v(const In* in, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = cast_to<Out>(in_vec.val[i]);
+    out_vec[i] = cast_to<Out>(in_vec[i]);
   }
 
   store_vector(out, index, out_vec);
@@ -65,8 +65,7 @@ void copy_contiguous(
       using InType = cuda_type_t;
      using OutType = cuda_type_t;
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(InType);
       auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
       if (ctype == CopyType::Vector) {
         kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index c5ae14b382..8dd8f02762 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -32,21 +32,103 @@ using Strides = cuda::std::array;
 template <typename T, int N>
 struct alignas(sizeof(T) * N) AlignedVector {
   T val[N];
+
+  __device__ T& operator[](int i) {
+    return val[i];
+  }
+
+  __device__ T operator[](int i) const {
+    return val[i];
+  }
 };
 
+template <int N, typename T>
+inline __device__ bool is_aligned(T* x) {
+  return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
+}
+
 template <int N, typename T>
 inline __device__ AlignedVector<T, N> load_vector(
     const T* ptr,
     uint32_t offset) {
-  auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
-  return from[offset];
+  if (is_aligned<N>(ptr)) {
+    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+    return from[offset];
+  } else {
+    AlignedVector<T, N> v;
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      v[i] = ptr[offset * N + i];
+    }
+    return v;
+  }
+}
+
+template <int N, typename T, typename SizeT>
+inline __device__ AlignedVector<T, N>
+load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {
+  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
+    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+    return from[offset];
+  } else {
+    AlignedVector<T, N> v;
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;
+    }
+    return v;
+  }
+}
+
+template <int N, typename T, typename SizeT>
+inline __device__ AlignedVector<T, N> load_vector(
+    const T* ptr,
+    uint32_t offset,
+    SizeT size,
+    int64_t stride,
+    T fallback) {
+  if (is_aligned<N>(ptr) && stride == 1 && (offset + 1) * N <= size) {
+    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+    return from[offset];
+  } else {
+    AlignedVector<T, N> v;
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      v[i] =
+          (N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;
+    }
+    return v;
+  }
 }
 
 template <int N, typename T>
 inline __device__ void
 store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
-  auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
-  to[offset] = vec;
+  if (is_aligned<N>(ptr)) {
+    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
+    to[offset] = vec;
+  } else {
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      ptr[offset * N + i] = vec[i];
+    }
+  }
+}
+
+template <int N, typename T, typename SizeT>
+inline __device__ void store_vector(
+    T* ptr,
+    uint32_t offset,
+    const AlignedVector<T, N>& vec,
+    SizeT size) {
+  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
+    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
+    to[offset] = vec;
+  } else {
+    for (int i = 0; (offset * N + i) < size && i < N; ++i) {
+      ptr[offset * N + i] = vec[i];
+    }
+  }
 }
 
 // Helper for accessing strided data.
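The helpers added to utils.cuh above are the core of this change: load_vector/store_vector keep the single 16-byte access when the pointer is aligned and otherwise fall back to element-wise accesses, and the size-taking overloads additionally guard the ragged tail of a buffer. As a usage illustration only (this kernel is a sketch, not part of the patch), a contiguous elementwise kernel built on them looks roughly like this:

template <typename Op, typename T, typename IdxT, int N_READS>
__global__ void apply_v(const T* in, T* out, IdxT size) {
  IdxT index = blockIdx.x * static_cast<IdxT>(blockDim.x) + threadIdx.x;
  if (index * N_READS >= size) {
    return;
  }
  // Bounds-checked load: unaligned pointers and the final partial vector
  // are handled inside load_vector, with T(0) as the fill value.
  auto in_vec = load_vector<N_READS>(in, index, size, T(0));
  AlignedVector<T, N_READS> out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = Op{}(in_vec[i]);
  }
  // Bounds-checked store: writes only the elements that fall inside size.
  store_vector(out, index, out_vec, size);
}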
diff --git a/mlx/backend/cuda/gemms/gemv.cu b/mlx/backend/cuda/gemms/gemv.cu
index 9aaaa2541c..163945e79a 100644
--- a/mlx/backend/cuda/gemms/gemv.cu
+++ b/mlx/backend/cuda/gemms/gemv.cu
@@ -31,8 +31,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
     auto local_vec = load_vector<n_per_thread>(vec + col, 0);
 #pragma unroll
     for (int j = 0; j < n_per_thread; ++j) {
-      sum += static_cast<float>(local_mat.val[j]) *
-          static_cast<float>(local_vec.val[j]);
+      sum +=
+          static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
     }
   }
 
@@ -73,8 +73,7 @@ __global__ void gemv_batched(
 }
 
 bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
-  bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0;
-  return is_multiple && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
+  return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
 }
 
 template
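On the can_use_gemv change above: K % 64 == 0 and K % 128 == 0 both imply K % 32 == 0, so the old is_multiple expression reduces to the single modulus check and the new predicate accepts exactly the same K values. A throwaway host-side check of that equivalence (not part of the patch):

#include <cassert>

int main() {
  for (int k = 0; k <= 4096; ++k) {
    bool old_ok = (k % 32 == 0) || (k % 64 == 0) || (k % 128 == 0);
    assert(old_ok == (k % 32 == 0));
  }
  return 0;
}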
diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu
index fdb63d64c2..d0d0f80c82 100644
--- a/mlx/backend/cuda/layer_norm.cu
+++ b/mlx/backend/cuda/layer_norm.cu
@@ -10,8 +10,6 @@
 #include
 #include
 #include
-#include
-#include
 
 namespace mlx::core {
 
@@ -74,9 +72,11 @@ __global__ void layer_norm(
   float sum = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS] = {};
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      sum += static_cast<float>(xn[i]);
+    }
   }
   sum = BlockReduceT{block, temp}.Sum(sum);
 
@@ -87,11 +87,18 @@ __global__ void layer_norm(
   float normalizer = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
-    for (int i = 0; i < N_READS; ++i) {
-      float t = static_cast<float>(xn[i]) - mean;
-      normalizer += t * t;
+    if ((index + 1) * N_READS <= axis_size) {
+      auto xn = load_vector<N_READS>(x, index);
+#pragma unroll
+      for (int i = 0; i < N_READS; ++i) {
+        float t = static_cast<float>(xn[i]) - mean;
+        normalizer += t * t;
+      }
+    } else {
+      for (int i = index * N_READS; i < axis_size; ++i) {
+        float t = static_cast<float>(x[i]) - mean;
+        normalizer += t * t;
+      }
     }
   }
   normalizer = BlockReduceT{block, temp}.Sum(normalizer);
@@ -100,17 +107,15 @@
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    T bn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+    auto bn = load_vector<N_READS>(b, index, axis_size, b_stride, T(0));
+#pragma unroll
     for (int i = 0; i < N_READS; ++i) {
       float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
       xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
     }
-    cub::StoreDirectBlocked(index, out, xn, axis_size);
+    store_vector(out, index, xn, axis_size);
   }
 }
 
@@ -143,9 +148,11 @@ __global__ void layer_norm_vjp(
   float sum = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS] = {};
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      sum += static_cast<float>(xn[i]);
+    }
   }
   sum = BlockReduceF{block, temp.f}.Sum(sum);
 
@@ -155,19 +162,28 @@
   // Normalizer.
   float3 factors = {};
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
-    T xn[N_READS];
-    T wn[N_READS] = {};
-    T gn[N_READS] = {};
     auto index = r * BLOCK_DIM + block.thread_rank();
-    cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
-    for (int i = 0; i < N_READS; i++) {
-      float t = static_cast<float>(xn[i]) - mean;
-      float wi = wn[i];
-      float gi = gn[i];
-      float wg = wi * gi;
-      factors = plus_f3(factors, {wg, wg * t, t * t});
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+
+    if ((index + 1) * N_READS <= axis_size) {
+      auto xn = load_vector<N_READS>(x, index);
+#pragma unroll
+      for (int i = 0; i < N_READS; ++i) {
+        float t = static_cast<float>(xn[i]) - mean;
+        float wi = wn[i];
+        float gi = gn[i];
+        float wg = wi * gi;
+        factors = plus_f3(factors, {wg, wg * t, t * t});
+      }
+    } else {
+      for (int i = index * N_READS; i < axis_size; ++i) {
+        float t = static_cast<float>(x[i]) - mean;
+        float wi = wn[i];
+        float gi = gn[i];
+        float wg = wi * gi;
+        factors = plus_f3(factors, {wg, wg * t, t * t});
+      }
     }
   }
   factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
@@ -179,12 +195,10 @@
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    T gn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+
     for (int i = 0; i < N_READS; i++) {
       float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
       float wi = wn[i];
@@ -194,9 +208,9 @@
       wn[i] = gi * xi;
       }
     }
-    cub::StoreDirectBlocked(index, gx, xn, axis_size);
+    store_vector(gx, index, xn, axis_size);
     if constexpr (HAS_W) {
-      cub::StoreDirectBlocked(index, gw, wn, axis_size);
+      store_vector(gw, index, wn, axis_size);
     }
   }
 }
@@ -257,9 +271,9 @@ void LayerNorm::eval_gpu(
   encoder.set_input_array(b);
   encoder.set_output_array(out);
   dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
-    constexpr uint32_t N_READS = 4;
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    constexpr int N_READS = 16 / sizeof(DataType);
     dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
       encoder.add_kernel_node(
           kernel,
@@ -364,10 +378,10 @@ void LayerNormVJP::eval_gpu(
   encoder.set_output_array(gw_temp);
   dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
     dispatch_bool(has_w, [&](auto has_w_constant) {
-      constexpr int N_READS = 4;
+      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+      constexpr int N_READS = 16 / sizeof(DataType);
       dispatch_block_dim(
           cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
             auto kernel = cu::layer_norm_vjp<
                 DataType,
                 has_w_constant.value,
diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu
index 87cb3aedc4..269efc034b 100644
--- a/mlx/backend/cuda/reduce.cu
+++ b/mlx/backend/cuda/reduce.cu
@@ -5,8 +5,6 @@
 #include "mlx/backend/gpu/copy.h"
 
 #include
-#include
-#include
 #include
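The layer_norm hunks above swap cub::LoadDirectBlocked/StoreDirectBlocked for the new load_vector/store_vector helpers and, where the old fill value mattered, split each pass over the row into a vectorized fast path and a scalar tail. Reduced to its skeleton (illustrative only; accumulate stands in for the per-element work, while x, axis_size, BLOCK_DIM, N_READS and block are as in the kernels above):

for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
  auto index = r * BLOCK_DIM + block.thread_rank();
  if ((index + 1) * N_READS <= axis_size) {
    // Whole vector in range: one 16-byte (or element-wise fallback) load.
    auto xn = load_vector<N_READS>(x, index);
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      accumulate(xn[i]);
    }
  } else {
    // Ragged tail: plain scalar reads, no out-of-bounds access.
    for (int i = index * N_READS; i < axis_size; ++i) {
      accumulate(x[i]);
    }
  }
}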
diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu
index 48d6a82812..419f3d2179 100644
--- a/mlx/backend/cuda/rms_norm.cu
+++ b/mlx/backend/cuda/rms_norm.cu
@@ -10,8 +10,6 @@
 #include
 #include
 #include
-#include
-#include
 
 namespace mlx::core {
 
@@ -57,7 +55,7 @@ __global__ void rms_norm(
     const T* w,
     T* out,
     float eps,
-    int32_t axis_size,
+    uint32_t axis_size,
     int64_t w_stride) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
@@ -72,8 +70,8 @@ __global__ void rms_norm(
   float normalizer = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
     for (int i = 0; i < N_READS; ++i) {
       float t = static_cast<float>(xn[i]);
       normalizer += t * t;
@@ -85,15 +83,14 @@ __global__ void rms_norm(
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+#pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      float norm = static_cast<float>(xn[i]) * normalizer;
-      xn[i] = wn[i] * static_cast<T>(norm);
+      float y = static_cast<float>(xn[i]) * normalizer;
+      xn[i] = wn[i] * static_cast<T>(y);
     }
-    cub::StoreDirectBlocked(index, out, xn, axis_size);
+    store_vector(out, index, xn, axis_size);
   }
 }
 
@@ -125,13 +122,10 @@ __global__ void rms_norm_vjp(
   // Normalizer.
   float2 factors = {};
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
-    T xn[N_READS];
-    T wn[N_READS] = {};
-    T gn[N_READS] = {};
     auto index = r * BLOCK_DIM + block.thread_rank();
-    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
     for (int i = 0; i < N_READS; i++) {
       float t = static_cast<float>(xn[i]);
       float wi = wn[i];
@@ -148,12 +142,9 @@
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    T gn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
     for (int i = 0; i < N_READS; i++) {
       float xi = xn[i];
       float wi = wn[i];
@@ -163,9 +154,9 @@
       wn[i] = static_cast<T>(gi * xi * normalizer);
       }
     }
-    cub::StoreDirectBlocked(index, gx, xn, axis_size);
+    store_vector(gx, index, xn, axis_size);
     if constexpr (HAS_W) {
-      cub::StoreDirectBlocked(index, gw, wn, axis_size);
+      store_vector(gw, index, wn, axis_size);
     }
   }
 }
@@ -223,9 +214,9 @@ void RMSNorm::eval_gpu(
   encoder.set_input_array(w);
   encoder.set_output_array(out);
   dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
-    constexpr uint32_t N_READS = 4;
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    constexpr int N_READS = 16 / sizeof(DataType);
     dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
       encoder.add_kernel_node(
           kernel,
@@ -312,11 +303,10 @@ void RMSNormVJP::eval_gpu(
   encoder.set_output_array(gw_temp);
   dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
     dispatch_bool(has_w, [&](auto has_w_constant) {
-      constexpr int N_READS = 4;
+      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+      constexpr int N_READS = 16 / sizeof(DataType);
       dispatch_block_dim(
           cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-            constexpr int N_READS = 4;
             auto kernel = cu::rms_norm_vjp<
                 DataType,
                 has_w_constant.value,
diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu
index 9b208c4233..6f9c6c7e4d 100644
--- a/mlx/backend/cuda/ternary.cu
+++ b/mlx/backend/cuda/ternary.cu
@@ -32,7 +32,7 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
   AlignedVector<T, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]);
+    out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
   }
 
   store_vector(out, index, out_vec);
@@ -166,8 +166,7 @@ void ternary_op_gpu_inplace(
   } else {
     dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(DType);
       auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
       auto [num_blocks, block_dims] = get_launch_args(
           kernel,
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index 83bf834178..68d04b9eaf 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -30,7 +30,7 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(in_vec.val[i]);
+    out_vec[i] = Op{}(in_vec[i]);
   }
 
   store_vector(out, index, out_vec);
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index bbea9ad8e4..5bc51f2977 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -3049,6 +3049,25 @@ class TestOps(mlx_tests.MLXTestCase):
         out = mx.power(mx.array(0j), float("nan"))
         self.assertTrue(mx.isnan(out))
 
+    def test_irregular_alignments(self):
+        # Unaligned unary op
+        a = mx.ones((64, 1))
+        b = -a[1:]
+        self.assertTrue(mx.all(b == -1.0))
+
+        # Unaligned binary op
+        a = mx.ones((64, 1))
+        b = a[1:]
+        c = b + b
+        self.assertTrue(mx.all(c == 2.0))
+
+        # Unaligned ternary op
+        a = mx.ones((64, 1))
+        b = mx.zeros((63, 1))
+        c = mx.ones((63, 1)).astype(mx.bool_)
+        d = mx.where(c, a[1:], b)
+        self.assertTrue(mx.all(d == 1.0))
+
 
 class TestBroadcast(mlx_tests.MLXTestCase):
     def test_broadcast_shapes(self):
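The new test exercises the fallback paths: slicing with a[1:] shares the parent buffer but offsets the data pointer by one element, so it is no longer 16-byte aligned and the kernels must take the element-wise branch inside load_vector/store_vector. A small host-side sketch of the alignment math (assumes a 16-byte aligned base allocation; not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  alignas(16) float buf[64];     // aligned base, like a freshly allocated array
  const float* sliced = buf + 1; // analogue of a[1:]
  auto addr = reinterpret_cast<std::uintptr_t>(sliced);
  // With float and N_READS = 4 the vector width is 16 bytes; the sliced
  // pointer is off by 4 bytes, so an is_aligned<4> check would fail.
  assert(addr % (4 * sizeof(float)) != 0);
  return 0;
}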