Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)
Propagate nans in binary ops (#579)

* propagate nans in binary ops
* handle empty matmul
* cpu minimum/maximum propagate nan
* benchmark maximum
* add min as well
* throw on negative indices with full
* verbose on linux
* fix matmul for zero K
This commit is contained in:
parent 37d98ba6ff
commit 3c2f192345
@@ -42,7 +42,7 @@ jobs:
       - run:
           name: Run Python tests
           command: |
-            python3 -m unittest discover python/tests
+            python3 -m unittest discover python/tests -v
       # TODO: Reenable when extension api becomes stable
       # - run:
       #     name: Build example extension
@@ -44,6 +44,13 @@ def time_matmul():
     time_fn(mx.matmul, a, b)
 
 
+def time_maximum():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    b = mx.random.uniform(shape=(32, 1024, 1024))
+    mx.eval(a, b)
+    time_fn(mx.maximum, a, b)
+
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)
@@ -101,6 +108,7 @@ if __name__ == "__main__":
 
     time_add()
     time_matmul()
+    time_maximum()
     time_exp()
     time_negative()
     time_logsumexp()
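As a point of reference, here is a self-contained sketch of the kind of timing loop a benchmark like time_maximum relies on. The time_op helper below is hypothetical (the repository uses its own time_fn utility, whose definition is not part of this diff); the explicit mx.eval calls matter because MLX evaluates lazily.

# Minimal stand-alone timing sketch (hypothetical helper, not the repo's
# time_fn). mx.eval forces evaluation so lazy graphs don't skew timings.
import time
import mlx.core as mx

def time_op(op, *args, iters=100):
    mx.eval(op(*args))  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(op(*args))
    return (time.perf_counter() - start) / iters

a = mx.random.uniform(shape=(32, 1024, 1024))
b = mx.random.uniform(shape=(32, 1024, 1024))
mx.eval(a, b)
print(f"maximum: {time_op(mx.maximum, a, b) * 1e3:.3f} ms per call")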
@@ -46,6 +46,11 @@ inline void matmul_cblas_general(
   size_t N = b.shape(-1);
   size_t K = a.shape(-1);
 
+  if (K == 0) {
+    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
+    return;
+  }
+
   for (int i = 0; i < (a.size() / (M * K)); ++i) {
     cblas_sgemm(
         CblasRowMajor,
@@ -89,6 +94,11 @@ inline void matmul_bnns_general(
   size_t N = b.shape(-1);
   size_t K = a.shape(-1);
 
+  if (K == 0) {
+    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
+    return;
+  }
+
   BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
 
   const BNNSLayerParametersBroadcastMatMul gemm_params{
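The K == 0 early-out above is not just cosmetic: the batched loop divides by M * K, so a zero-length inner dimension would trigger an integer division by zero before BLAS is ever called. Semantically, an (M, 0) by (0, N) product is an all-zero (M, N) matrix, which memset provides directly. Seen from the Python API (a sketch; the shapes mirror the unit test added later in this commit):

# Behavior the zero-K guard implements, seen from Python.
import mlx.core as mx

a = mx.zeros((10, 0))   # M x K with K == 0
b = mx.zeros((0, 10))   # K x N
out = a @ b             # reduction over an empty axis
print(out.shape)                                        # (10, 10)
print(mx.array_equal(out, mx.zeros((10, 10))).item())   # True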
@@ -50,6 +50,8 @@ DEFAULT(LogicalNot)
 DEFAULT(LogicalAnd)
 DEFAULT(LogicalOr)
 DEFAULT(LogAddExp)
+DEFAULT(Maximum)
+DEFAULT(Minimum)
 DEFAULT(NotEqual)
 DEFAULT(Pad)
 DEFAULT(Partition)
@@ -396,47 +398,6 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  if (out.dtype() == float32) {
-    binary(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return (x > y) ? x : y; },
-        UseDefaultBinaryOp(),
-        UseDefaultBinaryOp(),
-        [](const auto* a, const auto* b, auto* out, int n) {
-          vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
-        });
-  } else {
-    binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
-  }
-}
-
-void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-
-  if (out.dtype() == float32) {
-    binary(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return (x < y) ? x : y; },
-        UseDefaultBinaryOp(),
-        UseDefaultBinaryOp(),
-        [](const auto* a, const auto* b, auto* out, int n) {
-          vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
-        });
-  } else {
-    binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
-  }
-}
-
 void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
@@ -233,14 +233,33 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
+
+  if (is_floating_point(out.dtype())) {
+    binary(a, b, out, [](auto x, auto y) {
+      if (std::isnan(x)) {
+        return x;
+      }
+      return (x > y) ? x : y;
+    });
+  } else {
+    binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
+  }
 }
 
 void Minimum::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
+  if (is_floating_point(out.dtype())) {
+    binary(a, b, out, [](auto x, auto y) {
+      if (std::isnan(x)) {
+        return x;
+      }
+      return (x < y) ? x : y;
+    });
+  } else {
+    binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
+  }
 }
 
 void Multiply::eval(const std::vector<array>& inputs, array& out) {
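With the Accelerate-specific eval_cpu overrides removed, Maximum and Minimum fall back to this common CPU path via DEFAULT(...). The likely motivation is that the vDSP_vmax/vDSP_vmin fast paths make no NaN-propagation guarantee, while the new scalar lambdas return x whenever isnan(x) holds (and when y is NaN, the comparison is false, so the NaN y is returned). A quick sketch of the resulting user-visible behavior, matching the tests added in this commit:

# NaN propagation the new scalar lambdas implement: if either operand is
# NaN, the result is NaN.
import math
import mlx.core as mx

a = mx.array([float("nan")])
b = mx.array([0.0])
print(math.isnan(mx.maximum(a, b).item()))  # True
print(math.isnan(mx.maximum(b, a).item()))  # True
print(math.isnan(mx.minimum(a, b).item()))  # True

# Integer dtypes have no NaN, so they keep the plain comparison path.
print(mx.maximum(mx.array([1]), mx.array([2])).item())  # 2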
@@ -6,6 +6,8 @@
 #include <cblas.h>
 #endif
 
+#include <cstring>
+
 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/utils.h"
@@ -128,6 +130,11 @@ inline void matmul_common_general(
   size_t N = b.shape(-1);
   size_t K = a.shape(-1);
 
+  if (K == 0) {
+    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
+    return;
+  }
+
   for (int i = 0; i < (a.size() / (M * K)); ++i) {
     cblas_sgemm(
         CblasRowMajor,
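Why zero is the right fill value in the K == 0 case: each output entry is a sum over the contracted dimension, and an empty sum is zero by convention:

C_{ij} = \sum_{k=1}^{K} a_{ik} b_{kj}, \qquad K = 0 \;\Longrightarrow\; C_{ij} = \sum_{k \in \emptyset} a_{ik} b_{kj} = 0.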
@@ -58,6 +58,9 @@ struct LessEqual {
 struct LogAddExp {
   template <typename T>
   T operator()(T x, T y) {
+    if (metal::isnan(x) || metal::isnan(y)) {
+      return metal::numeric_limits<T>::quiet_NaN();
+    }
     constexpr T inf = metal::numeric_limits<T>::infinity();
     T maxval = metal::max(x, y);
     T minval = metal::min(x, y);
@@ -67,20 +70,48 @@ struct LogAddExp {
 };
 
 struct Maximum {
-  template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
+  template <typename T>
+  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
+    return metal::max(x, y);
+  }
+
+  template <typename T>
+  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
+    if (metal::isnan(x)) {
+      return x;
+    }
+    return x > y ? x : y;
+  }
 
   template <>
   complex64_t operator()(complex64_t x, complex64_t y) {
-    return x >= y ? x : y;
+    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
+      return x;
+    }
+    return x > y ? x : y;
   }
 };
 
 struct Minimum {
-  template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
+  template <typename T>
+  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
+    return metal::min(x, y);
+  }
+
+  template <typename T>
+  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
+    if (metal::isnan(x)) {
+      return x;
+    }
+    return x < y ? x : y;
+  }
 
   template <>
   complex64_t operator()(complex64_t x, complex64_t y) {
-    return x <= y ? x : y;
+    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
+      return x;
+    }
+    return x < y ? x : y;
   }
 };
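On the Metal side, the integral/non-integral split via metal::enable_if_t keeps integer types on the plain metal::max/metal::min path (integers have no NaN), while floating-point and complex types get the explicit isnan guard. LogAddExp needs its own check because its stable formulation, max(x, y) + log1p(exp(min(x, y) - max(x, y))), could otherwise lose a NaN operand inside the max/min shuffle. A small sketch of the resulting behavior from Python:

# LogAddExp now propagates NaN explicitly while staying numerically stable.
import math
import mlx.core as mx

a = mx.array([float("nan")])
b = mx.array([0.0])
print(math.isnan(mx.logaddexp(a, b).item()))  # True

x = mx.array([1000.0])
y = mx.array([1000.0])
print(mx.logaddexp(x, y).item())  # ~1000.693: no overflow from exp(1000)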
@@ -189,6 +189,9 @@ array full(
     const array& vals,
     Dtype dtype,
     StreamOrDevice s /* = {} */) {
+  if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
+    throw std::invalid_argument("[full] Negative dimensions not allowed.");
+  }
   auto in = broadcast_to(astype(vals, dtype, s), shape, s);
   return array(shape, dtype, std::make_unique<Full>(to_stream(s)), {in});
 }
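The new check in full() turns a negative dimension into an immediate std::invalid_argument rather than undefined behavior further down in broadcast_to. From Python this presumably surfaces as a ValueError (the usual translation for std::invalid_argument in the bindings):

# Exercising the new shape validation from Python; the ValueError mapping
# is assumed from the standard exception translation.
import mlx.core as mx

try:
    mx.full((-5, 0), 0)
except ValueError as e:
    print(e)  # [full] Negative dimensions not allowed.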
@@ -386,6 +386,11 @@ class TestOps(mlx_tests.MLXTestCase):
         expected = [0, -7, 3]
         self.assertListEqual(mx.minimum(x, y).tolist(), expected)
 
+        a = mx.array([float("nan")])
+        b = mx.array([0.0])
+        self.assertTrue(math.isnan(mx.minimum(a, b).item()))
+        self.assertTrue(math.isnan(mx.minimum(b, a).item()))
+
     def test_maximum(self):
         x = mx.array([0.0, -5, 10.0])
         y = mx.array([1.0, -7.0, 3.0])
@@ -393,6 +398,11 @@ class TestOps(mlx_tests.MLXTestCase):
         expected = [1, -5, 10]
         self.assertListEqual(mx.maximum(x, y).tolist(), expected)
 
+        a = mx.array([float("nan")])
+        b = mx.array([0.0])
+        self.assertTrue(math.isnan(mx.maximum(a, b).item()))
+        self.assertTrue(math.isnan(mx.maximum(b, a).item()))
+
     def test_floor(self):
         x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])
         expected = [-23, 19, -27, 9, 0, -np.inf, np.inf]
@@ -760,6 +770,10 @@ class TestOps(mlx_tests.MLXTestCase):
 
         self.assertTrue(np.allclose(result, expected))
 
+        a = mx.array([float("nan")])
+        b = mx.array([0.0])
+        self.assertTrue(math.isnan(mx.logaddexp(a, b).item()))
+
     def test_log(self):
         a = mx.array([1, 0.5, 10, 100])
         result = mx.log(a)
@@ -1761,6 +1775,16 @@ class TestOps(mlx_tests.MLXTestCase):
         )
         self.assertCmpNumpy([(3,), [2, 2, 2]], mx.tile, np.tile)
 
+    def test_empty_matmuls(self):
+        a = mx.array([])
+        b = mx.array([])
+        self.assertEqual(mx.inner(a, b).item(), 0.0)
+
+        a = mx.zeros((10, 0))
+        b = mx.zeros((0, 10))
+        out = a @ b
+        self.assertTrue(mx.array_equal(out, mx.zeros((10, 10))))
+
 
 if __name__ == "__main__":
     unittest.main()
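These tests pin down semantics that match NumPy, whose maximum, minimum, and logaddexp also propagate NaN; a quick cross-check sketch (not part of the test suite):

# Cross-checking the new NaN semantics against NumPy.
import numpy as np
import mlx.core as mx

for np_op, mx_op in [(np.maximum, mx.maximum),
                     (np.minimum, mx.minimum),
                     (np.logaddexp, mx.logaddexp)]:
    n = np_op(np.array([np.nan]), np.array([0.0]))[0]
    m = mx_op(mx.array([float("nan")]), mx.array([0.0])).item()
    assert np.isnan(n) and np.isnan(m)
print("NaN propagation matches NumPy")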
@@ -151,6 +151,12 @@ TEST_CASE("test astype") {
 }
 
 TEST_CASE("test full") {
+  // Check throws on bad shape
+  {
+    CHECK_THROWS(full({-5, 0}, 0));
+    CHECK_THROWS(full({0, -5}, 0));
+  }
+
   // Check full works for different types
   {
     auto x = full({}, 0);