mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
parent
d75ae52ecd
commit
d40a04f8dc
@ -46,6 +46,9 @@ inline void matmul_cblas_general(
|
|||||||
size_t N = b.shape(-1);
|
size_t N = b.shape(-1);
|
||||||
size_t K = a.shape(-1);
|
size_t K = a.shape(-1);
|
||||||
|
|
||||||
|
if (M == 0 || N == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (K == 0) {
|
if (K == 0) {
|
||||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||||
return;
|
return;
|
||||||
@ -94,6 +97,9 @@ inline void matmul_bnns_general(
|
|||||||
size_t N = b.shape(-1);
|
size_t N = b.shape(-1);
|
||||||
size_t K = a.shape(-1);
|
size_t K = a.shape(-1);
|
||||||
|
|
||||||
|
if (M == 0 || N == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (K == 0) {
|
if (K == 0) {
|
||||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||||
return;
|
return;
|
||||||
|
@ -132,7 +132,9 @@ inline void matmul_common_general(
|
|||||||
size_t M = a.shape(-2);
|
size_t M = a.shape(-2);
|
||||||
size_t N = b.shape(-1);
|
size_t N = b.shape(-1);
|
||||||
size_t K = a.shape(-1);
|
size_t K = a.shape(-1);
|
||||||
|
if (M == 0 || N == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (K == 0) {
|
if (K == 0) {
|
||||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||||
return;
|
return;
|
||||||
|
@ -1323,8 +1323,8 @@ array mean(
|
|||||||
for (int axis : axes) {
|
for (int axis : axes) {
|
||||||
if (axis < -ndim || axis >= ndim) {
|
if (axis < -ndim || axis >= ndim) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[mean] axis " << axis + " is out of bounds for array with "
|
msg << "[mean] axis " << axis << " is out of bounds for array with "
|
||||||
<< ndim + " dimensions.";
|
<< ndim << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1364,7 +1364,7 @@ array var(
|
|||||||
|
|
||||||
if (ddof != 0) {
|
if (ddof != 0) {
|
||||||
auto nelements = compute_number_of_elements(a, axes);
|
auto nelements = compute_number_of_elements(a, axes);
|
||||||
float factor = nelements / (nelements - ddof);
|
auto factor = nelements / static_cast<float>(std::max(nelements - ddof, 0));
|
||||||
v = multiply(v, array(factor, dtype), s);
|
v = multiply(v, array(factor, dtype), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -582,6 +582,25 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(r.shape, t.shape)
|
self.assertEqual(r.shape, t.shape)
|
||||||
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
||||||
|
|
||||||
|
def test_empty_matmul(self):
|
||||||
|
a = mx.array([[], []]).T
|
||||||
|
b = mx.array([[1.0, 2.0], [2.0, 3.0]])
|
||||||
|
c = a @ b
|
||||||
|
mx.eval(c)
|
||||||
|
self.assertEqual(c.shape, (0, 2))
|
||||||
|
|
||||||
|
a = mx.array([[1.0, 2.0], [2.0, 3.0]])
|
||||||
|
b = mx.array([[], []])
|
||||||
|
c = a @ b
|
||||||
|
mx.eval(c)
|
||||||
|
self.assertEqual(c.shape, (2, 0))
|
||||||
|
|
||||||
|
a = mx.array([[], []]).T
|
||||||
|
b = mx.array([[], []])
|
||||||
|
c = a @ b
|
||||||
|
mx.eval(c)
|
||||||
|
self.assertEqual(c.shape, (0, 0))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
@ -690,6 +690,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0])
|
self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0])
|
||||||
self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25])
|
self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25])
|
||||||
|
|
||||||
|
x = mx.array([1.0, 2.0])
|
||||||
|
out = mx.var(x, ddof=2)
|
||||||
|
self.assertEqual(out.item(), float("inf"))
|
||||||
|
|
||||||
|
x = mx.array([1.0, 2.0])
|
||||||
|
out = mx.var(x, ddof=3)
|
||||||
|
self.assertEqual(out.item(), float("inf"))
|
||||||
|
|
||||||
def test_abs(self):
|
def test_abs(self):
|
||||||
a = mx.array([-1.0, 1.0, -2.0, 3.0])
|
a = mx.array([-1.0, 1.0, -2.0, 3.0])
|
||||||
result = mx.abs(a)
|
result = mx.abs(a)
|
||||||
|
Loading…
Reference in New Issue
Block a user