Fix eye for larger matrices (#463)

* fix eye
* fix scatter for <32-bit (non-native atomic) types
* fix int overflow
This commit is contained in:
Awni Hannun
2024-01-16 00:51:24 -08:00
committed by GitHub
parent c15fe3e61b
commit a2ffea683a
6 changed files with 109 additions and 65 deletions

View File

@@ -1,5 +1,6 @@
set(
HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
${CMAKE_CURRENT_SOURCE_DIR}/complex.h

View File

@@ -38,49 +38,59 @@ struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
METAL_FUNC void mlx_atomic_fetch_and_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
METAL_FUNC void mlx_atomic_fetch_min_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
METAL_FUNC void mlx_atomic_fetch_max_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
METAL_FUNC void mlx_atomic_fetch_add_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
T expected = mlx_atomic_load_explicit(object, offset);
while (!mlx_atomic_compare_exchange_weak_explicit(
object, &expected, val * expected, offset)) {
@@ -92,7 +102,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread T* expected,
T val,
int offset) {
uint offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
@@ -106,7 +116,7 @@ template <>
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
device mlx_atomic<float>* object,
float val,
int offset) {
uint offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val < expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
@@ -121,7 +131,7 @@ template <>
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
device mlx_atomic<float>* object,
float val,
int offset) {
uint offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val > expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
@@ -148,7 +158,7 @@ union uint_or_packed {
template <typename T, typename Op>
struct mlx_atomic_update_helper {
uint operator()(uint_or_packed<T> init, T update, int elem_offset) {
uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
Op op;
init.val[elem_offset] = op(update, init.val[elem_offset]);
return init.bits;
@@ -159,9 +169,9 @@ template <typename T, typename Op>
METAL_FUNC void mlx_atomic_update_and_store(
device mlx_atomic<T>* object,
T update,
int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
mlx_atomic_update_helper<T, Op> helper;
uint_or_packed<T> expected;
@@ -242,9 +252,9 @@ struct __Min {
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
int pack_offset = offset / sizeof(T);
int elem_offset = offset % sizeof(T);
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
uint pack_offset = offset / sizeof(T);
uint elem_offset = offset % sizeof(T);
uint_or_packed<T> packed_val;
packed_val.bits =
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
@@ -253,15 +263,17 @@ mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
METAL_FUNC void mlx_atomic_fetch_and_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = __UINT32_MAX__;
identity.val[elem_offset] = val;
@@ -272,9 +284,9 @@ mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = 0;
identity.val[elem_offset] = val;
@@ -284,26 +296,34 @@ mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
METAL_FUNC void mlx_atomic_fetch_min_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
METAL_FUNC void mlx_atomic_fetch_max_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
METAL_FUNC void mlx_atomic_fetch_add_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
}
@@ -312,11 +332,11 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread uint* expected,
uint val,
int offset) {
uint offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
val,
memory_order_relaxed,
memory_order_relaxed);
}
}

View File

@@ -173,8 +173,7 @@ template <typename T, typename IdxT, typename Op, int NIDX>
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out + out_idx + out_offset, updates[upd_idx]);
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
}
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \

View File

@@ -16,7 +16,7 @@ union bool4_or_uint {
struct None {
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
mlx_atomic_store_explicit(out, val, offset);
}
};
@@ -41,7 +41,7 @@ struct And {
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
if (!val) {
mlx_atomic_store_explicit(out, val, offset);
}
@@ -68,8 +68,8 @@ struct Or {
void atomic_update(
device mlx_atomic<unsigned int>* out,
bool val,
int elem_idx,
int offset = 0) {
uint elem_idx,
uint offset = 0) {
if (val) {
bool4_or_uint update;
update.b = {false, false, false, false};
@@ -78,7 +78,7 @@ struct Or {
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
if (val) {
mlx_atomic_store_explicit(out, val, offset);
}
@@ -105,7 +105,7 @@ struct Sum {
static constexpr constant U init = U(0);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
mlx_atomic_fetch_add_explicit(out, val, offset);
}
@@ -125,7 +125,7 @@ struct Prod {
static constexpr constant U init = U(1);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
mlx_atomic_fetch_mul_explicit(out, val, offset);
}
@@ -145,7 +145,7 @@ struct Min {
static constexpr constant U init = Limits<U>::max;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
mlx_atomic_fetch_min_explicit(out, val, offset);
}
@@ -165,7 +165,7 @@ struct Max {
static constexpr constant U init = Limits<U>::min;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
mlx_atomic_fetch_max_explicit(out, val, offset);
}

View File

@@ -218,20 +218,20 @@ array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
if (n <= 0 || m <= 0) {
throw std::invalid_argument("N and M must be positive integers.");
}
array result = zeros({n * m}, dtype, s);
array result = zeros({n, m}, dtype, s);
if (k >= m || -k >= n) {
return reshape(result, {n, m}, s);
return result;
}
int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);
int start_index = (k >= 0) ? k : -k * m;
array diag_indices_array = arange(
start_index, start_index + diagonal_length * (m + 1), m + 1, int32, s);
array ones_array = ones({diagonal_length, 1}, dtype, s);
result = scatter(result, diag_indices_array, ones_array, 0, s);
return reshape(result, {n, m}, s);
std::vector<array> indices;
auto s1 = std::max(0, -k);
auto s2 = std::max(0, k);
indices.push_back(arange(s1, diagonal_length + s1, int32, s));
indices.push_back(arange(s2, diagonal_length + s2, int32, s));
array ones_array = ones({diagonal_length, 1, 1}, dtype, s);
return scatter(result, indices, ones_array, {0, 1}, s);
}
array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {