From e6418781abd4232fcb3613cc300a025ccf5f405b Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 25 Feb 2024 08:39:55 -0800
Subject: [PATCH] Fix logsumexp edge case (#740)

* fix logsumexp

* fix inf constant

* also fix power grad

* fix ternary dispatch
---
 mlx/backend/common/ops.h                 |  4 ++
 mlx/backend/metal/kernels/CMakeLists.txt |  2 +
 mlx/backend/metal/kernels/binary.metal   | 10 ---
 mlx/backend/metal/kernels/ternary.metal  | 29 ++++++--
 mlx/backend/metal/kernels/unary.h        |  4 ++
 mlx/backend/metal/primitives.cpp         | 85 +++++++++++++-----------
 mlx/ops.cpp                              |  6 +-
 mlx/primitives.cpp                       |  5 +-
 python/tests/test_autograd.py            |  8 +++
 python/tests/test_compile.py             |  9 +++
 python/tests/test_load.py                |  8 ++-
 tests/ops_tests.cpp                      |  6 +-
 12 files changed, 112 insertions(+), 64 deletions(-)

diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h
index 560296622d..b5b0953b2a 100644
--- a/mlx/backend/common/ops.h
+++ b/mlx/backend/common/ops.h
@@ -7,6 +7,10 @@
 
 namespace mlx::core::detail {
 
+namespace {
+constexpr float inf = std::numeric_limits<float>::infinity();
+} // namespace
+
 typedef union {
   int i;
   float f;
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 2b97ff76c1..326e0760ee 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -3,11 +3,13 @@ set(
   ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
   ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
   ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
+  ${CMAKE_CURRENT_SOURCE_DIR}/binary.h
   ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
   ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
+  ${CMAKE_CURRENT_SOURCE_DIR}/unary.h
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
 )
 
diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal
index 4d449ab691..eff687231a 100644
--- a/mlx/backend/metal/kernels/binary.metal
+++ b/mlx/backend/metal/kernels/binary.metal
@@ -2,16 +2,6 @@
 
 #include "mlx/backend/metal/kernels/binary.h"
 
-template <typename T, typename U, typename Op>
-[[kernel]] void binary_op_s2s(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[0], b[0]);
-}
-
-
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_op_ss(
     device const T* a,
diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal
index f3021fc110..c351bed176 100644
--- a/mlx/backend/metal/kernels/ternary.metal
+++ b/mlx/backend/metal/kernels/ternary.metal
@@ -7,6 +7,16 @@
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/ternary.h"
 
+template <typename T, typename Op>
+[[kernel]] void ternary_op_v(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    uint index [[thread_position_in_grid]]) {
+  d[index] = Op()(a[index], b[index], c[index]);
+}
+
 template <typename T, typename Op>
 [[kernel]] void ternary_op_g_nd1(
     device const bool* a,
@@ -94,6 +104,15 @@
   d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
 }
+#define instantiate_ternary_v(name, type, op) \
+  template [[host_name(name)]] \
+  [[kernel]] void ternary_op_v<type, op>( \
+      device const bool* a, \
+      device const type* b, \
+      device const type* c, \
+      device type* d, \
+      uint index [[thread_position_in_grid]]); \
+
 #define instantiate_ternary_g(name, type, op) \
   template [[host_name(name)]] \
   [[kernel]] void ternary_op_g<type, op>( \
       device const bool* a, \
@@ -160,14 +179,10 @@
   instantiate_ternary_g_dim(name, type, op, 5) \
 
 #define instantiate_ternary_all(name, tname, type, op) \
+  instantiate_ternary_v("v" #name #tname, type, op) \
   instantiate_ternary_g("g" #name #tname, type, op) \
   instantiate_ternary_g_nd("g" #name #tname, type, op) \
 
-#define instantiate_ternary_float(name, op) \
-  instantiate_ternary_all(name, float16, half, op) \
-  instantiate_ternary_all(name, float32, float, op) \
-  instantiate_ternary_all(name, bfloat16, bfloat16_t, op)
-
 #define instantiate_ternary_types(name, op) \
   instantiate_ternary_all(name, bool_, bool, op) \
   instantiate_ternary_all(name, uint8, uint8_t, op) \
@@ -178,7 +193,9 @@
   instantiate_ternary_all(name, int16, int16_t, op) \
   instantiate_ternary_all(name, int32, int32_t, op) \
   instantiate_ternary_all(name, int64, int64_t, op) \
+  instantiate_ternary_all(name, float16, half, op) \
+  instantiate_ternary_all(name, float32, float, op) \
+  instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
   instantiate_ternary_all(name, complex64, complex64_t, op) \
-  instantiate_ternary_float(name, op)
 
 instantiate_ternary_types(select, Select)
diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h
index 6d086b7753..e0d80ab102 100644
--- a/mlx/backend/metal/kernels/unary.h
+++ b/mlx/backend/metal/kernels/unary.h
@@ -9,6 +9,10 @@
 #include "mlx/backend/metal/kernels/erf.h"
 #include "mlx/backend/metal/kernels/utils.h"
 
+namespace {
+constant float inf = metal::numeric_limits<float>::infinity();
+}
+
 struct Abs {
   template <typename T>
   T operator()(T x) {
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 301adcdeae..0f2716a1b6 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
-
 #include
 #include
 #include
@@ -240,11 +239,15 @@ void ternary_op(
   auto& strides_out = strides[3];
 
   std::ostringstream kname;
-  kname << "g";
-  kname << op << type_to_name(b);
-  if (topt == TernaryOpType::General &&
-      shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
-    kname << "_" << shape.size();
+  if (topt == TernaryOpType::General) {
+    kname << "g";
+    kname << op << type_to_name(b);
+    if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
+      kname << "_" << shape.size();
+    }
+  } else {
+    kname << "v";
+    kname << op << type_to_name(b);
   }
 
   auto& s = out.primitive().stream();
@@ -257,44 +260,46 @@ void ternary_op(
   set_array_buffer(compute_encoder, c, 2);
   set_array_buffer(compute_encoder, out, 3);
 
-  auto ndim = shape.size();
-  if (ndim > 3) {
-    compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
-    compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
-    compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
-    compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
+  if (topt == TernaryOpType::General) {
+    auto ndim = shape.size();
+    if (ndim > 3) {
+      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
+      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
+      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
+      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
 
-    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 8);
+      if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
+        compute_encoder->setBytes(&ndim, sizeof(int), 8);
+      }
+    } else {
+      // The shape is implicit in the grid for <= 3D
+      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
+      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
+      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
     }
-  } else if (ndim > 0) {
-    // The shape is implicit in the grid for <= 3D
-    compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-    compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-    compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
-  } else {
-    // For 0-dim we still need to bind something to these buffers since the
-    // current ternary kernels always access the strides.
-    size_t dummy_stride = 0;
-    int dummy_shape = 0;
-    compute_encoder->setBytes(&dummy_shape, sizeof(int), 4);
-    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 5);
-    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 6);
-    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 7);
-    compute_encoder->setBytes(&ndim, sizeof(int), 8);
-  }
 
-  // Launch up to 3D grid of threads
-  size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
-  size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
-  size_t rest = out.size() / (dim0 * dim1);
-  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-  if (thread_group_size != 1024) {
-    throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
+    // Launch up to 3D grid of threads
+    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
+    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
+    size_t rest = out.size() / (dim0 * dim1);
+    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+    if (thread_group_size != 1024) {
+      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
+    }
+    MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
+    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+  } else {
+    // Launch a 1D grid of threads
+    size_t nthreads = out.data_size();
+    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
+    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+    if (thread_group_size > nthreads) {
+      thread_group_size = nthreads;
+    }
+    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
   }
-  MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
-  MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
 }
 
 void unary_op(
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 2c3cf55abf..df933a5159 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1755,7 +1755,11 @@ array logsumexp(
     StreamOrDevice s /* = {}*/) {
   auto maxval = stop_gradient(max(a, axes, true, s));
   auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
-  return add(out, reshape(maxval, out.shape(), s), s);
+  out = add(out, reshape(maxval, out.shape(), s), s);
+  if (!keepdims) {
+    maxval = squeeze(maxval, axes, s);
+  }
+  return where(isinf(maxval, s), maxval, out, s);
 }
 
 array logsumexp(
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 96fa310043..de2c010929 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2043,7 +2043,10 @@ std::vector<array> Power::vjp(
           primals[1],
           stream()));
     } else {
-      vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
+      auto& exp = outputs[0];
+      auto exp_vjp = multiply(log(primals[0], stream()), outputs[0], stream());
+      // 0 * log 0 -> 0
+      vjps.push_back(where(exp, exp_vjp, array(0.0f, exp.dtype()), stream()));
     }
     vjps.back() = multiply(cotangents[0], vjps.back(), stream());
   }
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 6c5c922b15..28054fb9b0 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -415,6 +415,14 @@ class TestAutograd(mlx_tests.MLXTestCase):
         _, vjps = mx.vjp(func, (arr,), (cotan,))
         self.assertEqual(vjps[0].item(), 8.0)
 
+    def test_power_grad(self):
+        def fun(x, y):
+            res = x - y
+            return res**x
+
+        grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
+        self.assertEqual(grad.item(), 1.0)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py
index e531344826..2d0b22cddd 100644
--- a/python/tests/test_compile.py
+++ b/python/tests/test_compile.py
@@ -539,6 +539,15 @@ class TestCompile(mlx_tests.MLXTestCase):
         z = fun(mx.array(1), "two")
         self.assertEqual(z.item(), 3)
 
+    def test_compile_inf(self):
+
+        @mx.compile
+        def fun(x):
+            return mx.isinf(x + 2)
+
+        out = fun(mx.array([0.0]))
+        self.assertEqual(out.item(), False)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_load.py b/python/tests/test_load.py
index fdf06041ac..2a0ae35c3c 100644
--- a/python/tests/test_load.py
+++ b/python/tests/test_load.py
@@ -66,13 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
     def test_save_and_load_safetensors(self):
         if not os.path.isdir(self.test_dir):
             os.mkdir(self.test_dir)
+
+        test_file = os.path.join(self.test_dir, "test.safetensors")
         with self.assertRaises(Exception):
-            mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
+            mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
 
         mx.save_safetensors(
-            "test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
+            test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
         )
-        res = mx.load("test.safetensors", return_metadata=True)
+        res = mx.load(test_file, return_metadata=True)
         self.assertEqual(len(res), 2)
         self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
 
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 775a850a02..d63c1f2cdf 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -791,13 +791,13 @@ TEST_CASE("test reduction ops") {
   constexpr float inf = std::numeric_limits<float>::infinity();
 
   x = array({-inf, -inf});
-  WARN_EQ(logsumexp(x).item<float>(), -inf);
+  CHECK_EQ(logsumexp(x).item<float>(), -inf);
 
   x = array({0.0f, -inf});
   CHECK_EQ(logsumexp(x).item<float>(), 0.0f);
 
   x = array({0.0f, inf});
-  WARN_EQ(logsumexp(x).item<float>(), inf);
+  CHECK_EQ(logsumexp(x).item<float>(), inf);
 
   x = reshape(arange(6, float32), {2, 3});
 
@@ -2819,4 +2819,4 @@ TEST_CASE("test atleast_3d") {
   out = atleast_3d(x);
   CHECK_EQ(out.ndim(), 3);
   CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
-}
\ No newline at end of file
+}
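
Note on the logsumexp change (illustrative, not part of the patch): the old formula log(sum(exp(a - max(a)))) + max(a) produces NaN whenever the running max is infinite, e.g. an all -inf input hits (-inf) - (-inf) in the subtraction. The patched mlx/ops.cpp returns maxval directly when it is infinite. The sketch below mirrors that logic with Python mlx ops; the helper name logsumexp_stable is hypothetical and only illustrates the C++ change.

import mlx.core as mx

def logsumexp_stable(a, axis=None, keepdims=False):
    # Same structure as the patched C++ logsumexp: subtract the (stop-gradient)
    # max for stability, then fall back to maxval wherever it is +/-inf.
    maxval = mx.stop_gradient(mx.max(a, axis=axis, keepdims=True))
    out = mx.log(mx.sum(mx.exp(a - maxval), axis=axis, keepdims=keepdims))
    out = out + mx.reshape(maxval, out.shape)
    if not keepdims:
        maxval = mx.squeeze(maxval, axis=axis)
    # If the max is infinite, the naive path is NaN but the true result is maxval.
    return mx.where(mx.isinf(maxval), maxval, out)

print(logsumexp_stable(mx.array([-float("inf"), -float("inf")])))  # -inf, not nan
print(logsumexp_stable(mx.array([0.0, float("inf")])))             # inf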
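
Note on the Power vjp change (illustrative, not part of the patch): the gradient of x**y with respect to the exponent is x**y * log(x), so at x == 0 the naive product evaluates 0 * log(0) = 0 * (-inf) = nan. The patched mlx/primitives.cpp masks that term to 0 wherever the primal output is 0, which is what makes test_power_grad above pass. A small repro following that test:

import mlx.core as mx

def fun(x, y):
    res = x - y    # res == 0 when x == y
    return res**x  # exponent path of the vjp hits out * log(res) = 0 * log(0)

# With the masked vjp the derivative at x = y = 1 is 1.0; previously it was nan.
print(mx.grad(fun)(mx.array(1.0), mx.array(1.0)))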