mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-04 01:36:42 +08:00
fix gemv regression (#2445)
This commit is contained in:
parent
b405591249
commit
d32519c8ee
@ -43,10 +43,18 @@ struct alignas(sizeof(T) * N) AlignedVector {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
inline __device__ bool is_aligned(T* x) {
|
inline __host__ __device__ bool is_aligned(T* x) {
|
||||||
return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
|
return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
inline __device__ AlignedVector<T, N> unsafe_load_vector(
|
||||||
|
const T* ptr,
|
||||||
|
uint32_t offset) {
|
||||||
|
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
||||||
|
return from[offset];
|
||||||
|
}
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
inline __device__ AlignedVector<T, N> load_vector(
|
inline __device__ AlignedVector<T, N> load_vector(
|
||||||
const T* ptr,
|
const T* ptr,
|
||||||
@ -101,6 +109,13 @@ inline __device__ AlignedVector<T, N> load_vector(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
inline __device__ void
|
||||||
|
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
||||||
|
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||||
|
to[offset] = vec;
|
||||||
|
}
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
inline __device__ void
|
inline __device__ void
|
||||||
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
||||||
|
@ -27,8 +27,9 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
|
|||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
for (int col = n_per_thread * warp.thread_rank(); col < cols;
|
for (int col = n_per_thread * warp.thread_rank(); col < cols;
|
||||||
col += (WARP_SIZE * n_per_thread)) {
|
col += (WARP_SIZE * n_per_thread)) {
|
||||||
auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0);
|
auto local_mat =
|
||||||
auto local_vec = load_vector<n_per_thread>(vec + col, 0);
|
unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);
|
||||||
|
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < n_per_thread; ++j) {
|
for (int j = 0; j < n_per_thread; ++j) {
|
||||||
sum +=
|
sum +=
|
||||||
@ -127,9 +128,13 @@ void gemv(
|
|||||||
rows = M;
|
rows = M;
|
||||||
}
|
}
|
||||||
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
|
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
|
||||||
int n_per_t = 4;
|
int n_per_t;
|
||||||
while (K % (n_per_t * WARP_SIZE) != 0) {
|
if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {
|
||||||
n_per_t >>= 1;
|
n_per_t = 4;
|
||||||
|
} else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {
|
||||||
|
n_per_t = 2;
|
||||||
|
} else {
|
||||||
|
n_per_t = 1;
|
||||||
}
|
}
|
||||||
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
|
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
|
||||||
if (batch_count == 1) {
|
if (batch_count == 1) {
|
||||||
|
@ -47,7 +47,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
|
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
|
||||||
|
|
||||||
def test_matmul_unaligned(self):
|
def test_matmul_unaligned(self):
|
||||||
if not mx.metal.is_available():
|
if not mx.is_available(mx.gpu):
|
||||||
return
|
return
|
||||||
|
|
||||||
for dtype in self.dtypes:
|
for dtype in self.dtypes:
|
||||||
@ -61,8 +61,15 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
shape_b = (dim + p, dim + p)
|
shape_b = (dim + p, dim + p)
|
||||||
self.__gemm_test(shape_a, shape_b, np_dtype)
|
self.__gemm_test(shape_a, shape_b, np_dtype)
|
||||||
|
|
||||||
|
def test_matvec_unaligned(self):
|
||||||
|
a = mx.random.normal(shape=(4, 128))
|
||||||
|
b = mx.random.normal(shape=(129,))[1:]
|
||||||
|
out = a @ b
|
||||||
|
np_out = np.array(a) @ np.array(b)
|
||||||
|
self.assertTrue(np.allclose(out, np_out))
|
||||||
|
|
||||||
def test_matmul_shapes(self):
|
def test_matmul_shapes(self):
|
||||||
if not mx.metal.is_available():
|
if not mx.is_available(mx.gpu):
|
||||||
return
|
return
|
||||||
|
|
||||||
shapes = [
|
shapes = [
|
||||||
@ -1274,7 +1281,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
def test_gemv_gemm_same_precision(self):
|
def test_gemv_gemm_same_precision(self):
|
||||||
mx.random.seed(0)
|
mx.random.seed(0)
|
||||||
N = 256
|
N = 256
|
||||||
if mx.metal.is_available():
|
if mx.is_available(mx.gpu):
|
||||||
t = mx.bfloat16
|
t = mx.bfloat16
|
||||||
a = mx.random.normal([1, N]).astype(t)
|
a = mx.random.normal([1, N]).astype(t)
|
||||||
b = mx.concatenate([a, a], axis=0).astype(t)
|
b = mx.concatenate([a, a], axis=0).astype(t)
|
||||||
|
Loading…
Reference in New Issue
Block a user