diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 3ec50fb67..bc3da3018 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -1,5 +1,6 @@ set( HEADERS + ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h ${CMAKE_CURRENT_SOURCE_DIR}/complex.h diff --git a/mlx/backend/metal/kernels/atomic.h b/mlx/backend/metal/kernels/atomic.h index 1c260d1b6..c0f4b9ed8 100644 --- a/mlx/backend/metal/kernels/atomic.h +++ b/mlx/backend/metal/kernels/atomic.h @@ -38,49 +38,59 @@ struct mlx_atomic>> { template , bool> = true> METAL_FUNC T -mlx_atomic_load_explicit(device mlx_atomic* object, int offset) { +mlx_atomic_load_explicit(device mlx_atomic* object, uint offset) { return atomic_load_explicit(&(object[offset].val), memory_order_relaxed); } template , bool> = true> METAL_FUNC void -mlx_atomic_store_explicit(device mlx_atomic* object, T val, int offset) { +mlx_atomic_store_explicit(device mlx_atomic* object, T val, uint offset) { atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_and_explicit(device mlx_atomic* object, T val, int offset) { +METAL_FUNC void mlx_atomic_fetch_and_explicit( + device mlx_atomic* object, + T val, + uint offset) { atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void -mlx_atomic_fetch_or_explicit(device mlx_atomic* object, T val, int offset) { +mlx_atomic_fetch_or_explicit(device mlx_atomic* object, T val, uint offset) { atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_min_explicit(device mlx_atomic* object, T val, int offset) { +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + T val, + uint offset) { atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_max_explicit(device mlx_atomic* object, T val, int offset) { +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + T val, + uint offset) { atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_add_explicit(device mlx_atomic* object, T val, int offset) { +METAL_FUNC void mlx_atomic_fetch_add_explicit( + device mlx_atomic* object, + T val, + uint offset) { atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_mul_explicit(device mlx_atomic* object, T val, int offset) { +METAL_FUNC void mlx_atomic_fetch_mul_explicit( + device mlx_atomic* object, + T val, + uint offset) { T expected = mlx_atomic_load_explicit(object, offset); while (!mlx_atomic_compare_exchange_weak_explicit( object, &expected, val * expected, offset)) { @@ -92,7 +102,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( device mlx_atomic* object, thread T* expected, T val, - int offset) { + uint offset) { return atomic_compare_exchange_weak_explicit( &(object[offset].val), expected, @@ -106,7 +116,7 @@ template <> METAL_FUNC void mlx_atomic_fetch_min_explicit( device mlx_atomic* object, float val, - int offset) { + uint offset) { float expected = mlx_atomic_load_explicit(object, offset); while (val < expected) { if (mlx_atomic_compare_exchange_weak_explicit( @@ -121,7 +131,7 @@ template <> METAL_FUNC void mlx_atomic_fetch_max_explicit( device mlx_atomic* object, float val, - int offset) { + uint offset) { float expected = mlx_atomic_load_explicit(object, offset); while (val > expected) { if (mlx_atomic_compare_exchange_weak_explicit( @@ -148,7 +158,7 @@ union uint_or_packed { template struct mlx_atomic_update_helper { - uint operator()(uint_or_packed init, T update, int elem_offset) { + uint operator()(uint_or_packed init, T update, uint elem_offset) { Op op; init.val[elem_offset] = op(update, init.val[elem_offset]); return init.bits; @@ -159,9 +169,9 @@ template METAL_FUNC void mlx_atomic_update_and_store( device mlx_atomic* object, T update, - int offset) { - int pack_offset = offset / packing_size; - int elem_offset = offset % packing_size; + uint offset) { + uint pack_offset = offset / packing_size; + uint elem_offset = offset % packing_size; mlx_atomic_update_helper helper; uint_or_packed expected; @@ -242,9 +252,9 @@ struct __Min { template , bool> = true> METAL_FUNC T -mlx_atomic_load_explicit(device mlx_atomic* object, int offset) { - int pack_offset = offset / sizeof(T); - int elem_offset = offset % sizeof(T); +mlx_atomic_load_explicit(device mlx_atomic* object, uint offset) { + uint pack_offset = offset / sizeof(T); + uint elem_offset = offset % sizeof(T); uint_or_packed packed_val; packed_val.bits = atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); @@ -253,15 +263,17 @@ mlx_atomic_load_explicit(device mlx_atomic* object, int offset) { template , bool> = true> METAL_FUNC void -mlx_atomic_store_explicit(device mlx_atomic* object, T val, int offset) { +mlx_atomic_store_explicit(device mlx_atomic* object, T val, uint offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_and_explicit(device mlx_atomic* object, T val, int offset) { - int pack_offset = offset / packing_size; - int elem_offset = offset % packing_size; +METAL_FUNC void mlx_atomic_fetch_and_explicit( + device mlx_atomic* object, + T val, + uint offset) { + uint pack_offset = offset / packing_size; + uint elem_offset = offset % packing_size; uint_or_packed identity; identity.bits = __UINT32_MAX__; identity.val[elem_offset] = val; @@ -272,9 +284,9 @@ mlx_atomic_fetch_and_explicit(device mlx_atomic* object, T val, int offset) { template , bool> = true> METAL_FUNC void -mlx_atomic_fetch_or_explicit(device mlx_atomic* object, T val, int offset) { - int pack_offset = offset / packing_size; - int elem_offset = offset % packing_size; +mlx_atomic_fetch_or_explicit(device mlx_atomic* object, T val, uint offset) { + uint pack_offset = offset / packing_size; + uint elem_offset = offset % packing_size; uint_or_packed identity; identity.bits = 0; identity.val[elem_offset] = val; @@ -284,26 +296,34 @@ mlx_atomic_fetch_or_explicit(device mlx_atomic* object, T val, int offset) { } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_min_explicit(device mlx_atomic* object, T val, int offset) { +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + T val, + uint offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_max_explicit(device mlx_atomic* object, T val, int offset) { +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + T val, + uint offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_add_explicit(device mlx_atomic* object, T val, int offset) { +METAL_FUNC void mlx_atomic_fetch_add_explicit( + device mlx_atomic* object, + T val, + uint offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_mul_explicit(device mlx_atomic* object, T val, int offset) { +METAL_FUNC void mlx_atomic_fetch_mul_explicit( + device mlx_atomic* object, + T val, + uint offset) { mlx_atomic_update_and_store>(object, val, offset); } @@ -312,11 +332,11 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( device mlx_atomic* object, thread uint* expected, uint val, - int offset) { + uint offset) { return atomic_compare_exchange_weak_explicit( &(object[offset].val), expected, val, memory_order_relaxed, memory_order_relaxed); -} \ No newline at end of file +} diff --git a/mlx/backend/metal/kernels/indexing.metal b/mlx/backend/metal/kernels/indexing.metal index daba39a47..82642d250 100644 --- a/mlx/backend/metal/kernels/indexing.metal +++ b/mlx/backend/metal/kernels/indexing.metal @@ -173,8 +173,7 @@ template auto out_offset = elem_to_loc( ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim); - - op.atomic_update(out + out_idx + out_offset, updates[upd_idx]); + op.atomic_update(out, updates[upd_idx], out_idx + out_offset); } #define instantiate_scatter4(name, type, ind_type, op_type, nindex) \ diff --git a/mlx/backend/metal/kernels/reduce.h b/mlx/backend/metal/kernels/reduce.h index 1d2b971b2..70701aebd 100644 --- a/mlx/backend/metal/kernels/reduce.h +++ b/mlx/backend/metal/kernels/reduce.h @@ -16,7 +16,7 @@ union bool4_or_uint { struct None { template - void atomic_update(device mlx_atomic* out, T val, int offset = 0) { + void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { mlx_atomic_store_explicit(out, val, offset); } }; @@ -41,7 +41,7 @@ struct And { } } - void atomic_update(device mlx_atomic* out, bool val, int offset = 0) { + void atomic_update(device mlx_atomic* out, bool val, uint offset = 0) { if (!val) { mlx_atomic_store_explicit(out, val, offset); } @@ -68,8 +68,8 @@ struct Or { void atomic_update( device mlx_atomic* out, bool val, - int elem_idx, - int offset = 0) { + uint elem_idx, + uint offset = 0) { if (val) { bool4_or_uint update; update.b = {false, false, false, false}; @@ -78,7 +78,7 @@ struct Or { } } - void atomic_update(device mlx_atomic* out, bool val, int offset = 0) { + void atomic_update(device mlx_atomic* out, bool val, uint offset = 0) { if (val) { mlx_atomic_store_explicit(out, val, offset); } @@ -105,7 +105,7 @@ struct Sum { static constexpr constant U init = U(0); template - void atomic_update(device mlx_atomic* out, T val, int offset = 0) { + void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { mlx_atomic_fetch_add_explicit(out, val, offset); } @@ -125,7 +125,7 @@ struct Prod { static constexpr constant U init = U(1); template - void atomic_update(device mlx_atomic* out, T val, int offset = 0) { + void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { mlx_atomic_fetch_mul_explicit(out, val, offset); } @@ -145,7 +145,7 @@ struct Min { static constexpr constant U init = Limits::max; template - void atomic_update(device mlx_atomic* out, T val, int offset = 0) { + void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { mlx_atomic_fetch_min_explicit(out, val, offset); } @@ -165,7 +165,7 @@ struct Max { static constexpr constant U init = Limits::min; template - void atomic_update(device mlx_atomic* out, T val, int offset = 0) { + void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { mlx_atomic_fetch_max_explicit(out, val, offset); } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 6840afb62..61983197a 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -218,20 +218,20 @@ array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) { if (n <= 0 || m <= 0) { throw std::invalid_argument("N and M must be positive integers."); } - array result = zeros({n * m}, dtype, s); + array result = zeros({n, m}, dtype, s); if (k >= m || -k >= n) { - return reshape(result, {n, m}, s); + return result; } int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m); - int start_index = (k >= 0) ? k : -k * m; - array diag_indices_array = arange( - start_index, start_index + diagonal_length * (m + 1), m + 1, int32, s); - array ones_array = ones({diagonal_length, 1}, dtype, s); - result = scatter(result, diag_indices_array, ones_array, 0, s); - - return reshape(result, {n, m}, s); + std::vector indices; + auto s1 = std::max(0, -k); + auto s2 = std::max(0, k); + indices.push_back(arange(s1, diagonal_length + s1, int32, s)); + indices.push_back(arange(s2, diagonal_length + s2, int32, s)); + array ones_array = ones({diagonal_length, 1, 1}, dtype, s); + return scatter(result, indices, ones_array, {0, 1}, s); } array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) { diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 1c00f4541..2c4bacbbf 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1,5 +1,6 @@ // Copyright © 2023 Apple Inc. #include +#include // TODO #include #include "doctest/doctest.h" @@ -509,13 +510,14 @@ TEST_CASE("test is inf") { array x(1.0f); CHECK_FALSE(isinf(x).item()); - array y(std::numeric_limits::infinity()); + auto inf = std::numeric_limits::infinity(); + array y(inf); CHECK(isinf(y).item()); array z = identity(7); CHECK_FALSE(any(isinf(z)).item()); - array w = array({1.0f, std::numeric_limits::infinity(), 2.0f}); + array w = array({1.0f, inf, 2.0f}); CHECK(array_equal({false, true, false}, isinf(w)).item()); array a(1.0f, bfloat16); @@ -524,10 +526,10 @@ TEST_CASE("test is inf") { array b(1.0f, float16); CHECK_FALSE(isinf(b).item()); - array c(std::numeric_limits::infinity(), bfloat16); + array c(inf, bfloat16); CHECK(isinf(c).item()); - array d(std::numeric_limits::infinity(), float16); + array d(inf, float16); CHECK(isinf(d).item()); } @@ -1878,6 +1880,28 @@ TEST_CASE("test scatter") { CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item()); } +TEST_CASE("test scatter types") { + for (auto t : {bool_, uint8, uint16, int8, int16}) { + auto in = zeros({4, 4}, t); + auto inds = {arange(4), arange(4)}; + auto updates = ones({4, 1, 1}, t); + auto out = scatter(in, inds, updates, {0, 1}); + auto expected = + array({1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}, {4, 4}, t); + CHECK(array_equal(out, expected).item()); + } + + for (auto t : {float16, bfloat16}) { + auto in = zeros({4, 4}, t); + auto inds = {arange(4), arange(4)}; + auto updates = ones({4, 1, 1}, t); + auto out = scatter(in, inds, updates, {0, 1}); + auto expected = + array({1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}, {4, 4}, t); + CHECK(allclose(out, expected).item()); + } +} + TEST_CASE("test complex ops") { // Creation ops {