mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-10 05:59:04 +08:00
Start to cleanup/unify accelerate and common back-ends (Part 1/N) (#1777)
* start to cleanup/unify accelerate and common back-ends * more progress * simplify * add half type and allow infs in simd exp * unify softmax + quantized, more dispatches to simd quantized mm * add sin/cos, use simd in vector-scalar ops * faster CPU vectorize quant * faster erf/erfinv
This commit is contained in:
@@ -741,7 +741,7 @@ void init_array(nb::module_& m) {
|
||||
[](const mx::array& a) {
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise inversion.");
|
||||
"Floating point types not allowed with bitwise inversion.");
|
||||
}
|
||||
if (a.dtype() != mx::bool_) {
|
||||
throw std::invalid_argument(
|
||||
@@ -791,7 +791,7 @@ void init_array(nb::module_& m) {
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
"Floating point types not allowed with bitwise or.");
|
||||
}
|
||||
return mx::bitwise_or(a, b);
|
||||
},
|
||||
@@ -806,7 +806,7 @@ void init_array(nb::module_& m) {
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
"Floating point types not allowed with bitwise or.");
|
||||
}
|
||||
a.overwrite_descriptor(mx::bitwise_or(a, b));
|
||||
return a;
|
||||
@@ -838,7 +838,7 @@ void init_array(nb::module_& m) {
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or left shift.");
|
||||
"Floating point types not allowed with left shift.");
|
||||
}
|
||||
a.overwrite_descriptor(mx::left_shift(a, b));
|
||||
return a;
|
||||
@@ -870,7 +870,7 @@ void init_array(nb::module_& m) {
|
||||
if (mx::issubdtype(a.dtype(), mx::inexact) ||
|
||||
mx::issubdtype(b.dtype(), mx::inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or right shift.");
|
||||
"Floating point types not allowed with right shift.");
|
||||
}
|
||||
a.overwrite_descriptor(mx::right_shift(a, b));
|
||||
return a;
|
||||
|
||||
@@ -289,6 +289,9 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
|
||||
self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1])
|
||||
|
||||
z = -mx.ones(64) % mx.full(64, 2)
|
||||
self.assertTrue(mx.array_equal(z, mx.ones(64)))
|
||||
|
||||
def test_comparisons(self):
|
||||
a = mx.array([0.0, 1.0, 5.0])
|
||||
b = mx.array([-1.0, 2.0, 5.0])
|
||||
|
||||
@@ -207,8 +207,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
|
||||
x_shape = (1, N) if B == 0 else (B, 1, N)
|
||||
w_shape = (N, M) if B == 0 else (B, N, M)
|
||||
x = mx.random.normal(shape=x_shape, key=k1)
|
||||
w = mx.random.normal(shape=w_shape, key=k2)
|
||||
x = 1e-1 * mx.random.normal(shape=x_shape, key=k1)
|
||||
w = 1e-1 * mx.random.normal(shape=w_shape, key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
y_q = mx.quantized_matmul(
|
||||
|
||||
Reference in New Issue
Block a user