mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-22 05:08:08 +08:00
[CUDA] Fix alpha not respected when using bias epilogue (#2578)
This commit is contained in:
@@ -248,11 +248,19 @@ void CublasGemm::run(
|
|||||||
const array& b,
|
const array& b,
|
||||||
const Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const Strides& b_batch_strides) {
|
const Strides& b_batch_strides,
|
||||||
|
float alpha) {
|
||||||
int batch_count = out.size() / (M_ * N_);
|
int batch_count = out.size() / (M_ * N_);
|
||||||
if (batch_count / batch_shape.back() > 1) {
|
if (batch_count / batch_shape.back() > 1) {
|
||||||
run_batched(
|
run_batched(
|
||||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
encoder,
|
||||||
|
out,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
batch_shape,
|
||||||
|
a_batch_strides,
|
||||||
|
b_batch_strides,
|
||||||
|
alpha);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,7 +268,13 @@ void CublasGemm::run(
|
|||||||
encoder.set_input_array(b);
|
encoder.set_input_array(b);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
|
execute(
|
||||||
|
encoder,
|
||||||
|
out.data<void>(),
|
||||||
|
a.data<void>(),
|
||||||
|
b.data<void>(),
|
||||||
|
nullptr,
|
||||||
|
alpha);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CublasGemm::run(
|
void CublasGemm::run(
|
||||||
|
@@ -64,7 +64,8 @@ class CublasGemm {
|
|||||||
const array& b,
|
const array& b,
|
||||||
const Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const Strides& b_batch_strides);
|
const Strides& b_batch_strides,
|
||||||
|
float alpha = 1.0f);
|
||||||
|
|
||||||
void run(
|
void run(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
@@ -87,7 +88,8 @@ class CublasGemm {
|
|||||||
const array& b,
|
const array& b,
|
||||||
const Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const Strides& b_batch_strides);
|
const Strides& b_batch_strides,
|
||||||
|
float alpha);
|
||||||
|
|
||||||
void run_batched(
|
void run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
|
@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
|
|||||||
const array& b,
|
const array& b,
|
||||||
const Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const Strides& b_batch_strides) {
|
const Strides& b_batch_strides,
|
||||||
|
float alpha) {
|
||||||
encoder.set_input_array(a);
|
encoder.set_input_array(a);
|
||||||
encoder.set_input_array(b);
|
encoder.set_input_array(b);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
|
|||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||||
b.data<int8_t>() + b.itemsize() * b_it.loc,
|
b.data<int8_t>() + b.itemsize() * b_it.loc,
|
||||||
nullptr);
|
nullptr,
|
||||||
|
alpha);
|
||||||
a_it.step();
|
a_it.step();
|
||||||
b_it.step();
|
b_it.step();
|
||||||
}
|
}
|
||||||
|
@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
|
|||||||
const array& b,
|
const array& b,
|
||||||
const Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const Strides& b_batch_strides) {
|
const Strides& b_batch_strides,
|
||||||
|
float alpha) {
|
||||||
int batch_count = out.size() / (M_ * N_);
|
int batch_count = out.size() / (M_ * N_);
|
||||||
set_pointer_mode(a_desc_, batch_count);
|
set_pointer_mode(a_desc_, batch_count);
|
||||||
set_pointer_mode(b_desc_, batch_count);
|
set_pointer_mode(b_desc_, batch_count);
|
||||||
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
|
|||||||
reinterpret_cast<void*>(out_pointers),
|
reinterpret_cast<void*>(out_pointers),
|
||||||
reinterpret_cast<void*>(a_pointers),
|
reinterpret_cast<void*>(a_pointers),
|
||||||
reinterpret_cast<void*>(b_pointers),
|
reinterpret_cast<void*>(b_pointers),
|
||||||
nullptr);
|
nullptr,
|
||||||
|
alpha);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CublasGemm::run_batched(
|
void CublasGemm::run_batched(
|
||||||
|
@@ -41,7 +41,8 @@ void gemm_and_bias(
|
|||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
void* bias = nullptr) {
|
void* bias = nullptr,
|
||||||
|
float alpha = 1.0f) {
|
||||||
// Check and collapse batch dimensions
|
// Check and collapse batch dimensions
|
||||||
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||||
|
|
||||||
@@ -94,7 +95,8 @@ void gemm_and_bias(
|
|||||||
if (bias) {
|
if (bias) {
|
||||||
gemm.set_bias(bias);
|
gemm.set_bias(bias);
|
||||||
}
|
}
|
||||||
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
gemm.run(
|
||||||
|
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -169,7 +171,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
out,
|
out,
|
||||||
a,
|
a,
|
||||||
b,
|
b,
|
||||||
c.data<void>());
|
c.data<void>(),
|
||||||
|
alpha_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -594,124 +594,123 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
# Batched matmul
|
# Batched matmul
|
||||||
alpha = 0.5
|
alpha = 0.5
|
||||||
beta = 2.0
|
for beta in (1.0, 2.0):
|
||||||
|
# c must broadcast to the output shape
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
|
||||||
|
|
||||||
# c must broadcast to the output shape
|
# Regular batched case
|
||||||
with self.assertRaises(ValueError):
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
|
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
|
||||||
|
|
||||||
# Regular batched case
|
a_mlx = mx.array(a_npy)
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
b_mlx = mx.array(b_npy)
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
|
|
||||||
|
|
||||||
a_mlx = mx.array(a_npy)
|
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
|
||||||
b_mlx = mx.array(b_npy)
|
c_npy = np.ones(c_shape).astype(np.float32)
|
||||||
|
c_mlx = mx.array(c_npy)
|
||||||
|
|
||||||
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
|
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
||||||
c_npy = np.ones(c_shape).astype(np.float32)
|
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
||||||
c_mlx = mx.array(c_npy)
|
|
||||||
|
|
||||||
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||||
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||||
|
|
||||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
# Batched and transposed matmul
|
||||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
|
b_mlx = mx.array(b_npy)
|
||||||
|
|
||||||
# Batched and transposed matmul
|
for c_shape in ((1,), (32, 1, 128), (1, 128)):
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
c_npy = np.ones(c_shape).astype(np.float32)
|
||||||
b_mlx = mx.array(b_npy)
|
c_mlx = mx.array(c_npy)
|
||||||
|
|
||||||
for c_shape in ((1,), (32, 1, 128), (1, 128)):
|
b_np_t = np.transpose(b_npy, (0, 2, 1))
|
||||||
c_npy = np.ones(c_shape).astype(np.float32)
|
b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
|
||||||
c_mlx = mx.array(c_npy)
|
|
||||||
|
|
||||||
b_np_t = np.transpose(b_npy, (0, 2, 1))
|
d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
|
||||||
b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
|
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
|
||||||
|
|
||||||
d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
|
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||||
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||||
|
# Batched matmul with simple broadcast
|
||||||
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
||||||
|
|
||||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
a_mlx = mx.array(a_npy)
|
||||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
b_mlx = mx.array(b_npy)
|
||||||
# Batched matmul with simple broadcast
|
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
|
||||||
|
|
||||||
a_mlx = mx.array(a_npy)
|
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
|
||||||
b_mlx = mx.array(b_npy)
|
c_npy = np.ones(c_shape).astype(np.float32)
|
||||||
|
c_mlx = mx.array(c_npy)
|
||||||
|
|
||||||
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
|
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
||||||
c_npy = np.ones(c_shape).astype(np.float32)
|
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
||||||
c_mlx = mx.array(c_npy)
|
|
||||||
|
|
||||||
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||||
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||||
|
# Matmul with vector
|
||||||
|
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
b_mlx = mx.array(b_npy)
|
||||||
|
|
||||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
for c_shape in ((1,), (128,), (32, 128)):
|
||||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
c_npy = np.ones(c_shape).astype(np.float32)
|
||||||
# Matmul with vector
|
c_mlx = mx.array(c_npy)
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
|
|
||||||
a_mlx = mx.array(a_npy)
|
|
||||||
b_mlx = mx.array(b_npy)
|
|
||||||
|
|
||||||
for c_shape in ((1,), (128,), (32, 128)):
|
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
||||||
c_npy = np.ones(c_shape).astype(np.float32)
|
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
||||||
c_mlx = mx.array(c_npy)
|
|
||||||
|
|
||||||
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||||
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||||
|
|
||||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
# Matmul with vector
|
||||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
||||||
|
a_mlx = mx.array(a_npy)
|
||||||
|
b_mlx = mx.array(b_npy)
|
||||||
|
|
||||||
# Matmul with vector
|
for c_shape in ((1,), (32, 128)):
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
c_npy = np.ones(c_shape).astype(np.float32)
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
c_mlx = mx.array(c_npy)
|
||||||
a_mlx = mx.array(a_npy)
|
|
||||||
b_mlx = mx.array(b_npy)
|
|
||||||
|
|
||||||
for c_shape in ((1,), (32, 128)):
|
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
||||||
c_npy = np.ones(c_shape).astype(np.float32)
|
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
||||||
c_mlx = mx.array(c_npy)
|
|
||||||
|
|
||||||
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||||
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||||
|
|
||||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
# Split K specialization
|
||||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
|
||||||
|
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
|
||||||
|
|
||||||
# Split K specialization
|
a_mlx = mx.array(a_npy)
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
|
b_mlx = mx.array(b_npy)
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
|
|
||||||
|
|
||||||
a_mlx = mx.array(a_npy)
|
for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
|
||||||
b_mlx = mx.array(b_npy)
|
c_npy = np.ones(c_shape).astype(np.float32)
|
||||||
|
c_mlx = mx.array(c_npy)
|
||||||
|
|
||||||
for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
|
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
||||||
c_npy = np.ones(c_shape).astype(np.float32)
|
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
||||||
c_mlx = mx.array(c_npy)
|
|
||||||
|
|
||||||
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||||
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||||
|
|
||||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
# Transposed c
|
||||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
a = mx.ones((10, 5)).T
|
||||||
|
b = mx.ones((5, 5))
|
||||||
|
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
|
||||||
|
expected = beta * a + alpha * (b @ a)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
# Transposed c
|
# Broadcast c
|
||||||
a = mx.ones((10, 5)).T
|
a = mx.ones((5, 5))
|
||||||
b = mx.ones((5, 5))
|
b = mx.ones((5, 5))
|
||||||
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
|
c = mx.ones((1, 5))
|
||||||
expected = 1.5 * a + 0.5 * (b @ a)
|
out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
|
||||||
self.assertTrue(mx.allclose(expected, out))
|
expected = beta * c + alpha * (a @ b)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
# Broadcast c
|
|
||||||
a = mx.ones((5, 5))
|
|
||||||
b = mx.ones((5, 5))
|
|
||||||
c = mx.ones((1, 5))
|
|
||||||
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
|
|
||||||
expected = 1.5 * c + 0.5 * (a @ b)
|
|
||||||
self.assertTrue(mx.allclose(expected, out))
|
|
||||||
|
|
||||||
def test_addmm_grad(self):
|
def test_addmm_grad(self):
|
||||||
def make_ref_addmm(alpha, beta):
|
def make_ref_addmm(alpha, beta):
|
||||||
@@ -724,33 +723,32 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))
|
shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))
|
||||||
|
|
||||||
alpha = 2.0
|
alpha = 2.0
|
||||||
beta = 0.5
|
for beta in (1.0, 0.5):
|
||||||
|
f_test = make_addmm(alpha, beta)
|
||||||
|
f_ref = make_ref_addmm(alpha, beta)
|
||||||
|
|
||||||
f_test = make_addmm(alpha, beta)
|
for B, M, N, K in shapes:
|
||||||
f_ref = make_ref_addmm(alpha, beta)
|
cotan = mx.ones((B, M, N))
|
||||||
|
c = mx.random.normal((B, M, N))
|
||||||
|
a = mx.random.normal((B, M, K))
|
||||||
|
b = mx.random.normal((B, K, N))
|
||||||
|
|
||||||
for B, M, N, K in shapes:
|
out_ref, dout_ref = mx.vjp(
|
||||||
cotan = mx.ones((B, M, N))
|
f_ref,
|
||||||
c = mx.random.normal((B, M, N))
|
[c, a, b],
|
||||||
a = mx.random.normal((B, M, K))
|
[cotan],
|
||||||
b = mx.random.normal((B, K, N))
|
)
|
||||||
|
out_test, dout_test = mx.vjp(
|
||||||
|
f_test,
|
||||||
|
[c, a, b],
|
||||||
|
[cotan],
|
||||||
|
)
|
||||||
|
|
||||||
out_ref, dout_ref = mx.vjp(
|
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
|
||||||
f_ref,
|
|
||||||
[c, a, b],
|
|
||||||
[cotan],
|
|
||||||
)
|
|
||||||
out_test, dout_test = mx.vjp(
|
|
||||||
f_test,
|
|
||||||
[c, a, b],
|
|
||||||
[cotan],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
|
for r, t in zip(dout_ref, dout_test):
|
||||||
|
self.assertEqual(r.shape, t.shape)
|
||||||
for r, t in zip(dout_ref, dout_test):
|
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
||||||
self.assertEqual(r.shape, t.shape)
|
|
||||||
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
|
||||||
|
|
||||||
def test_empty_matmul(self):
|
def test_empty_matmul(self):
|
||||||
a = mx.array([[], []]).T
|
a = mx.array([[], []]).T
|
||||||
|
Reference in New Issue
Block a user