Propagate nans in binary ops (#579)

* propagate nans in binary ops

* handle empty matmul

* cpu minimum/maximum propagate nan

* benchmark maximum

* add min as well

* throw on negative indices with full

* verbose on linux

* fix matmul for zero K
This commit is contained in:
Awni Hannun 2024-01-29 11:19:38 -08:00 committed by GitHub
parent 37d98ba6ff
commit 3c2f192345
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 119 additions and 50 deletions

View File

@ -42,7 +42,7 @@ jobs:
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests
python3 -m unittest discover python/tests -v
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension

View File

@ -44,6 +44,13 @@ def time_matmul():
time_fn(mx.matmul, a, b)
def time_maximum():
    # Benchmark elementwise maximum on two large random tensors.
    shape = (32, 1024, 1024)
    lhs = mx.random.uniform(shape=shape)
    rhs = mx.random.uniform(shape=shape)
    # Force materialization so setup cost is excluded from the timing.
    mx.eval(lhs, rhs)
    time_fn(mx.maximum, lhs, rhs)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@ -101,6 +108,7 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_maximum()
time_exp()
time_negative()
time_logsumexp()

View File

@ -46,6 +46,11 @@ inline void matmul_cblas_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
@ -89,6 +94,11 @@ inline void matmul_bnns_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{
@ -201,4 +211,4 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core
} // namespace mlx::core

View File

@ -50,6 +50,8 @@ DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
@ -396,47 +398,6 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x > y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x < y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];

View File

@ -233,14 +233,33 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
if (is_floating_point(out.dtype())) {
binary(a, b, out, [](auto x, auto y) {
if (std::isnan(x)) {
return x;
}
return (x > y) ? x : y;
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
}
void Minimum::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  if (is_floating_point(out.dtype())) {
    // Propagate NaNs: return x when x is NaN; when y is NaN, (x < y) is
    // false so y (NaN) is returned. Either NaN operand yields NaN.
    binary(a, b, out, [](auto x, auto y) {
      if (std::isnan(x)) {
        return x;
      }
      return (x < y) ? x : y;
    });
  } else {
    // Integral dtypes have no NaN representation; plain comparison.
    binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
  }
}
void Multiply::eval(const std::vector<array>& inputs, array& out) {

View File

@ -6,6 +6,8 @@
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
@ -128,6 +130,11 @@ inline void matmul_common_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,

View File

@ -58,6 +58,9 @@ struct LessEqual {
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
@ -67,20 +70,48 @@ struct LogAddExp {
};
struct Maximum {
  // Integral types: no NaN representation, plain max.
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::max(x, y);
  }

  // Floating types: propagate NaN. Return x when x is NaN; when y is NaN,
  // (x > y) is false so y (NaN) is returned.
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x > y ? x : y;
  }

  // Complex: a NaN in either component of x propagates x; otherwise defer
  // to the complex64_t comparison operator.
  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x > y ? x : y;
  }
};
struct Minimum {
  // Integral types: no NaN representation, plain min.
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::min(x, y);
  }

  // Floating types: propagate NaN. Return x when x is NaN; when y is NaN,
  // (x < y) is false so y (NaN) is returned.
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x < y ? x : y;
  }

  // Complex: a NaN in either component of x propagates x; otherwise defer
  // to the complex64_t comparison operator.
  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x < y ? x : y;
  }
};
@ -389,4 +420,4 @@ instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)

View File

@ -189,6 +189,9 @@ array full(
const array& vals,
Dtype dtype,
StreamOrDevice s /* = {} */) {
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
throw std::invalid_argument("[full] Negative dimensions not allowed.");
}
auto in = broadcast_to(astype(vals, dtype, s), shape, s);
return array(shape, dtype, std::make_unique<Full>(to_stream(s)), {in});
}

View File

@ -386,6 +386,11 @@ class TestOps(mlx_tests.MLXTestCase):
expected = [0, -7, 3]
self.assertListEqual(mx.minimum(x, y).tolist(), expected)
a = mx.array([float("nan")])
b = mx.array([0.0])
self.assertTrue(math.isnan(mx.minimum(a, b).item()))
self.assertTrue(math.isnan(mx.minimum(b, a).item()))
def test_maximum(self):
x = mx.array([0.0, -5, 10.0])
y = mx.array([1.0, -7.0, 3.0])
@ -393,6 +398,11 @@ class TestOps(mlx_tests.MLXTestCase):
expected = [1, -5, 10]
self.assertListEqual(mx.maximum(x, y).tolist(), expected)
a = mx.array([float("nan")])
b = mx.array([0.0])
self.assertTrue(math.isnan(mx.maximum(a, b).item()))
self.assertTrue(math.isnan(mx.maximum(b, a).item()))
def test_floor(self):
x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])
expected = [-23, 19, -27, 9, 0, -np.inf, np.inf]
@ -760,6 +770,10 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(result, expected))
a = mx.array([float("nan")])
b = mx.array([0.0])
self.assertTrue(math.isnan(mx.logaddexp(a, b).item()))
def test_log(self):
a = mx.array([1, 0.5, 10, 100])
result = mx.log(a)
@ -1761,6 +1775,16 @@ class TestOps(mlx_tests.MLXTestCase):
)
self.assertCmpNumpy([(3,), [2, 2, 2]], mx.tile, np.tile)
def test_empty_matmuls(self):
    # Inner product of two empty vectors is the empty-sum identity, 0.
    empty = mx.array([])
    self.assertEqual(mx.inner(empty, empty).item(), 0.0)

    # A (10, 0) @ (0, 10) matmul has inner dimension K == 0 and must
    # produce an all-zero (10, 10) result.
    lhs = mx.zeros((10, 0))
    rhs = mx.zeros((0, 10))
    self.assertTrue(mx.array_equal(lhs @ rhs, mx.zeros((10, 10))))
if __name__ == "__main__":
unittest.main()

View File

@ -151,6 +151,12 @@ TEST_CASE("test astype") {
}
TEST_CASE("test full") {
// Check throws on bad shape
{
CHECK_THROWS(full({-5, 0}, 0));
CHECK_THROWS(full({0, -5}, 0));
}
// Check full works for different types
{
auto x = full({}, 0);