diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index f9a5c4e06f..bc055c9df2 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -43,10 +43,18 @@ struct alignas(sizeof(T) * N) AlignedVector {
 };
 
 template <int N, typename T>
-inline __device__ bool is_aligned(T* x) {
+inline __host__ __device__ bool is_aligned(T* x) {
   return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
 }
 
+template <int N, typename T>
+inline __device__ AlignedVector<T, N> unsafe_load_vector(
+    const T* ptr,
+    uint32_t offset) {
+  auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+  return from[offset];
+}
+
 template <int N, typename T>
 inline __device__ AlignedVector<T, N> load_vector(
     const T* ptr,
@@ -101,6 +109,13 @@ inline __device__ AlignedVector<T, N> load_vector(
   }
 }
 
+template <int N, typename T>
+inline __device__ void
+unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
+  auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
+  to[offset] = vec;
+}
+
 template <int N, typename T>
 inline __device__ void
 store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
diff --git a/mlx/backend/cuda/gemms/gemv.cu b/mlx/backend/cuda/gemms/gemv.cu
index 163945e79a..55333adea3 100644
--- a/mlx/backend/cuda/gemms/gemv.cu
+++ b/mlx/backend/cuda/gemms/gemv.cu
@@ -27,8 +27,9 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
     float sum = 0.0f;
     for (int col = n_per_thread * warp.thread_rank(); col < cols;
          col += (WARP_SIZE * n_per_thread)) {
-      auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0);
-      auto local_vec = load_vector<n_per_thread>(vec + col, 0);
+      auto local_mat =
+          unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);
+      auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
 #pragma unroll
       for (int j = 0; j < n_per_thread; ++j) {
         sum +=
@@ -127,9 +128,13 @@ void gemv(
       rows = M;
     }
     uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
-    int n_per_t = 4;
-    while (K % (n_per_t * WARP_SIZE) != 0) {
-      n_per_t >>= 1;
+    int n_per_t;
+    if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {
+      n_per_t = 4;
+    } else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {
+      n_per_t = 2;
+    } else {
+      n_per_t = 1;
     }
     dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
       if (batch_count == 1) {
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index 7cc39f06ab..bc675535af 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -47,7 +47,7 @@ class TestBlas(mlx_tests.MLXTestCase):
             self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
 
     def test_matmul_unaligned(self):
-        if not mx.metal.is_available():
+        if not mx.is_available(mx.gpu):
             return
 
         for dtype in self.dtypes:
@@ -61,8 +61,15 @@ class TestBlas(mlx_tests.MLXTestCase):
                     shape_b = (dim + p, dim + p)
                     self.__gemm_test(shape_a, shape_b, np_dtype)
 
+    def test_matvec_unaligned(self):
+        a = mx.random.normal(shape=(4, 128))
+        b = mx.random.normal(shape=(129,))[1:]
+        out = a @ b
+        np_out = np.array(a) @ np.array(b)
+        self.assertTrue(np.allclose(out, np_out))
+
     def test_matmul_shapes(self):
-        if not mx.metal.is_available():
+        if not mx.is_available(mx.gpu):
             return
 
         shapes = [
@@ -1274,7 +1281,7 @@ class TestBlas(mlx_tests.MLXTestCase):
     def test_gemv_gemm_same_precision(self):
         mx.random.seed(0)
         N = 256
-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             t = mx.bfloat16
             a = mx.random.normal([1, N]).astype(t)
             b = mx.concatenate([a, a], axis=0).astype(t)