mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
Fix matvec vector stride bug (#1168)
This commit is contained in:
parent
e7a2a3dcd1
commit
9f0df51f8d
@ -710,15 +710,19 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
||||
int M = a_pre.shape(-2);
|
||||
int N = b_pre.shape(-1);
|
||||
int K = a_pre.shape(-1);
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto check_transpose = [&copies, &s](const array& arr) {
|
||||
auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (sty == 1) {
|
||||
if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1) {
|
||||
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
@ -729,12 +733,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, a_cols, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, b_cols, b] = check_transpose(b_pre);
|
||||
|
||||
int M = a.shape(-2);
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
auto [a_transposed, a_cols, a] = check_transpose(a_pre, M == 1);
|
||||
auto [b_transposed, b_cols, b] = check_transpose(b_pre, N == 1);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
@ -517,6 +517,70 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(np.array_equal(c_mlx, c_npy))
|
||||
|
||||
def test_mismatch_stride_mm(self):
|
||||
np.random.seed(0)
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (4, 16, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (4, 16, 16)).astype(np.float32)
|
||||
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
# Matmul with batches
|
||||
c_npy = a_npy[::2, :, :] @ b_npy[1::2, :, :]
|
||||
c_mlx = a_mlx[::2, :, :] @ b_mlx[1::2, :, :]
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
|
||||
|
||||
# Matvec with batches
|
||||
c_npy = a_npy[::2, :, :] @ b_npy[1::2, :, 2:3]
|
||||
c_mlx = a_mlx[::2, :, :] @ b_mlx[1::2, :, 2:3]
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
|
||||
|
||||
# Matmul with slice
|
||||
c_npy = a_npy[:, :8, :] @ b_npy[:, :, :8]
|
||||
c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :, :8]
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
|
||||
|
||||
# Matmul with slice
|
||||
c_npy = a_npy[:, :, :8] @ b_npy[:, :8, :]
|
||||
c_mlx = a_mlx[:, :, :8] @ b_mlx[:, :8, :]
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
|
||||
|
||||
# Matmul transpose with slice
|
||||
c_npy = a_npy[:, :8, :] @ b_npy[:, :8, :].swapaxes(-1, -2)
|
||||
c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :8, :].swapaxes(-1, -2)
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
|
||||
|
||||
# Matmul transpose with slice
|
||||
c_npy = a_npy[:, :, :8] @ b_npy[:, :, :8].swapaxes(-1, -2)
|
||||
c_mlx = a_mlx[:, :, :8] @ b_mlx[:, :, :8].swapaxes(-1, -2)
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
|
||||
|
||||
# Matvec with slice
|
||||
c_npy = a_npy[:, :8, :] @ b_npy[:, :, 6:7]
|
||||
c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :, 6:7]
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
|
||||
|
||||
# Matvec with slice
|
||||
c_npy = a_npy[:, :, :8] @ b_npy[:, 3:11, 2:3]
|
||||
c_mlx = a_mlx[:, :, :8] @ b_mlx[:, 3:11, 2:3]
|
||||
|
||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
|
||||
|
||||
def test_addmm(self):
|
||||
np.random.seed(0)
|
||||
# Batched matmul
|
||||
|
Loading…
Reference in New Issue
Block a user