mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix eye for larger matrices (#463)
* fix eye
* fix scatter for <32-bit (non-native atomic) types
* fix int overflow
This commit is contained in:
parent
c15fe3e61b
commit
a2ffea683a
@ -1,5 +1,6 @@
|
|||||||
set(
|
set(
|
||||||
HEADERS
|
HEADERS
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
|
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
|
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
||||||
|
@ -38,49 +38,59 @@ struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
|
|||||||
|
|
||||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC T
|
METAL_FUNC T
|
||||||
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
|
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
|
||||||
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
|
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void
|
||||||
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
|
||||||
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
|
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void mlx_atomic_fetch_and_explicit(
|
||||||
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
device mlx_atomic<T>* object,
|
||||||
|
T val,
|
||||||
|
uint offset) {
|
||||||
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
|
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void
|
||||||
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
|
||||||
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
|
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void mlx_atomic_fetch_min_explicit(
|
||||||
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
device mlx_atomic<T>* object,
|
||||||
|
T val,
|
||||||
|
uint offset) {
|
||||||
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
|
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void mlx_atomic_fetch_max_explicit(
|
||||||
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
device mlx_atomic<T>* object,
|
||||||
|
T val,
|
||||||
|
uint offset) {
|
||||||
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
|
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void mlx_atomic_fetch_add_explicit(
|
||||||
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
device mlx_atomic<T>* object,
|
||||||
|
T val,
|
||||||
|
uint offset) {
|
||||||
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
|
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
|
||||||
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
device mlx_atomic<T>* object,
|
||||||
|
T val,
|
||||||
|
uint offset) {
|
||||||
T expected = mlx_atomic_load_explicit(object, offset);
|
T expected = mlx_atomic_load_explicit(object, offset);
|
||||||
while (!mlx_atomic_compare_exchange_weak_explicit(
|
while (!mlx_atomic_compare_exchange_weak_explicit(
|
||||||
object, &expected, val * expected, offset)) {
|
object, &expected, val * expected, offset)) {
|
||||||
@ -92,7 +102,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
|
|||||||
device mlx_atomic<T>* object,
|
device mlx_atomic<T>* object,
|
||||||
thread T* expected,
|
thread T* expected,
|
||||||
T val,
|
T val,
|
||||||
int offset) {
|
uint offset) {
|
||||||
return atomic_compare_exchange_weak_explicit(
|
return atomic_compare_exchange_weak_explicit(
|
||||||
&(object[offset].val),
|
&(object[offset].val),
|
||||||
expected,
|
expected,
|
||||||
@ -106,7 +116,7 @@ template <>
|
|||||||
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
|
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
|
||||||
device mlx_atomic<float>* object,
|
device mlx_atomic<float>* object,
|
||||||
float val,
|
float val,
|
||||||
int offset) {
|
uint offset) {
|
||||||
float expected = mlx_atomic_load_explicit(object, offset);
|
float expected = mlx_atomic_load_explicit(object, offset);
|
||||||
while (val < expected) {
|
while (val < expected) {
|
||||||
if (mlx_atomic_compare_exchange_weak_explicit(
|
if (mlx_atomic_compare_exchange_weak_explicit(
|
||||||
@ -121,7 +131,7 @@ template <>
|
|||||||
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
|
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
|
||||||
device mlx_atomic<float>* object,
|
device mlx_atomic<float>* object,
|
||||||
float val,
|
float val,
|
||||||
int offset) {
|
uint offset) {
|
||||||
float expected = mlx_atomic_load_explicit(object, offset);
|
float expected = mlx_atomic_load_explicit(object, offset);
|
||||||
while (val > expected) {
|
while (val > expected) {
|
||||||
if (mlx_atomic_compare_exchange_weak_explicit(
|
if (mlx_atomic_compare_exchange_weak_explicit(
|
||||||
@ -148,7 +158,7 @@ union uint_or_packed {
|
|||||||
|
|
||||||
template <typename T, typename Op>
|
template <typename T, typename Op>
|
||||||
struct mlx_atomic_update_helper {
|
struct mlx_atomic_update_helper {
|
||||||
uint operator()(uint_or_packed<T> init, T update, int elem_offset) {
|
uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
|
||||||
Op op;
|
Op op;
|
||||||
init.val[elem_offset] = op(update, init.val[elem_offset]);
|
init.val[elem_offset] = op(update, init.val[elem_offset]);
|
||||||
return init.bits;
|
return init.bits;
|
||||||
@ -159,9 +169,9 @@ template <typename T, typename Op>
|
|||||||
METAL_FUNC void mlx_atomic_update_and_store(
|
METAL_FUNC void mlx_atomic_update_and_store(
|
||||||
device mlx_atomic<T>* object,
|
device mlx_atomic<T>* object,
|
||||||
T update,
|
T update,
|
||||||
int offset) {
|
uint offset) {
|
||||||
int pack_offset = offset / packing_size<T>;
|
uint pack_offset = offset / packing_size<T>;
|
||||||
int elem_offset = offset % packing_size<T>;
|
uint elem_offset = offset % packing_size<T>;
|
||||||
|
|
||||||
mlx_atomic_update_helper<T, Op> helper;
|
mlx_atomic_update_helper<T, Op> helper;
|
||||||
uint_or_packed<T> expected;
|
uint_or_packed<T> expected;
|
||||||
@ -242,9 +252,9 @@ struct __Min {
|
|||||||
|
|
||||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC T
|
METAL_FUNC T
|
||||||
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
|
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
|
||||||
int pack_offset = offset / sizeof(T);
|
uint pack_offset = offset / sizeof(T);
|
||||||
int elem_offset = offset % sizeof(T);
|
uint elem_offset = offset % sizeof(T);
|
||||||
uint_or_packed<T> packed_val;
|
uint_or_packed<T> packed_val;
|
||||||
packed_val.bits =
|
packed_val.bits =
|
||||||
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
|
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
|
||||||
@ -253,15 +263,17 @@ mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
|
|||||||
|
|
||||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void
|
||||||
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
|
||||||
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
|
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void mlx_atomic_fetch_and_explicit(
|
||||||
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
device mlx_atomic<T>* object,
|
||||||
int pack_offset = offset / packing_size<T>;
|
T val,
|
||||||
int elem_offset = offset % packing_size<T>;
|
uint offset) {
|
||||||
|
uint pack_offset = offset / packing_size<T>;
|
||||||
|
uint elem_offset = offset % packing_size<T>;
|
||||||
uint_or_packed<T> identity;
|
uint_or_packed<T> identity;
|
||||||
identity.bits = __UINT32_MAX__;
|
identity.bits = __UINT32_MAX__;
|
||||||
identity.val[elem_offset] = val;
|
identity.val[elem_offset] = val;
|
||||||
@ -272,9 +284,9 @@ mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
|||||||
|
|
||||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void
|
||||||
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
|
||||||
int pack_offset = offset / packing_size<T>;
|
uint pack_offset = offset / packing_size<T>;
|
||||||
int elem_offset = offset % packing_size<T>;
|
uint elem_offset = offset % packing_size<T>;
|
||||||
uint_or_packed<T> identity;
|
uint_or_packed<T> identity;
|
||||||
identity.bits = 0;
|
identity.bits = 0;
|
||||||
identity.val[elem_offset] = val;
|
identity.val[elem_offset] = val;
|
||||||
@ -284,26 +296,34 @@ mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void mlx_atomic_fetch_min_explicit(
|
||||||
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
device mlx_atomic<T>* object,
|
||||||
|
T val,
|
||||||
|
uint offset) {
|
||||||
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
|
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void mlx_atomic_fetch_max_explicit(
|
||||||
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
device mlx_atomic<T>* object,
|
||||||
|
T val,
|
||||||
|
uint offset) {
|
||||||
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
|
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void mlx_atomic_fetch_add_explicit(
|
||||||
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
device mlx_atomic<T>* object,
|
||||||
|
T val,
|
||||||
|
uint offset) {
|
||||||
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
|
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
|
||||||
METAL_FUNC void
|
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
|
||||||
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
|
device mlx_atomic<T>* object,
|
||||||
|
T val,
|
||||||
|
uint offset) {
|
||||||
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
|
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -312,7 +332,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
|
|||||||
device mlx_atomic<T>* object,
|
device mlx_atomic<T>* object,
|
||||||
thread uint* expected,
|
thread uint* expected,
|
||||||
uint val,
|
uint val,
|
||||||
int offset) {
|
uint offset) {
|
||||||
return atomic_compare_exchange_weak_explicit(
|
return atomic_compare_exchange_weak_explicit(
|
||||||
&(object[offset].val),
|
&(object[offset].val),
|
||||||
expected,
|
expected,
|
||||||
|
@ -173,8 +173,7 @@ template <typename T, typename IdxT, typename Op, int NIDX>
|
|||||||
auto out_offset = elem_to_loc(
|
auto out_offset = elem_to_loc(
|
||||||
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||||
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
|
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
|
||||||
|
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
|
||||||
op.atomic_update(out + out_idx + out_offset, updates[upd_idx]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
|
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
|
||||||
|
@ -16,7 +16,7 @@ union bool4_or_uint {
|
|||||||
|
|
||||||
struct None {
|
struct None {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
|
||||||
mlx_atomic_store_explicit(out, val, offset);
|
mlx_atomic_store_explicit(out, val, offset);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -41,7 +41,7 @@ struct And {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
|
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
|
||||||
if (!val) {
|
if (!val) {
|
||||||
mlx_atomic_store_explicit(out, val, offset);
|
mlx_atomic_store_explicit(out, val, offset);
|
||||||
}
|
}
|
||||||
@ -68,8 +68,8 @@ struct Or {
|
|||||||
void atomic_update(
|
void atomic_update(
|
||||||
device mlx_atomic<unsigned int>* out,
|
device mlx_atomic<unsigned int>* out,
|
||||||
bool val,
|
bool val,
|
||||||
int elem_idx,
|
uint elem_idx,
|
||||||
int offset = 0) {
|
uint offset = 0) {
|
||||||
if (val) {
|
if (val) {
|
||||||
bool4_or_uint update;
|
bool4_or_uint update;
|
||||||
update.b = {false, false, false, false};
|
update.b = {false, false, false, false};
|
||||||
@ -78,7 +78,7 @@ struct Or {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
|
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
|
||||||
if (val) {
|
if (val) {
|
||||||
mlx_atomic_store_explicit(out, val, offset);
|
mlx_atomic_store_explicit(out, val, offset);
|
||||||
}
|
}
|
||||||
@ -105,7 +105,7 @@ struct Sum {
|
|||||||
static constexpr constant U init = U(0);
|
static constexpr constant U init = U(0);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
|
||||||
mlx_atomic_fetch_add_explicit(out, val, offset);
|
mlx_atomic_fetch_add_explicit(out, val, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,7 +125,7 @@ struct Prod {
|
|||||||
static constexpr constant U init = U(1);
|
static constexpr constant U init = U(1);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
|
||||||
mlx_atomic_fetch_mul_explicit(out, val, offset);
|
mlx_atomic_fetch_mul_explicit(out, val, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,7 +145,7 @@ struct Min {
|
|||||||
static constexpr constant U init = Limits<U>::max;
|
static constexpr constant U init = Limits<U>::max;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
|
||||||
mlx_atomic_fetch_min_explicit(out, val, offset);
|
mlx_atomic_fetch_min_explicit(out, val, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -165,7 +165,7 @@ struct Max {
|
|||||||
static constexpr constant U init = Limits<U>::min;
|
static constexpr constant U init = Limits<U>::min;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
|
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
|
||||||
mlx_atomic_fetch_max_explicit(out, val, offset);
|
mlx_atomic_fetch_max_explicit(out, val, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
18
mlx/ops.cpp
18
mlx/ops.cpp
@ -218,20 +218,20 @@ array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
|
|||||||
if (n <= 0 || m <= 0) {
|
if (n <= 0 || m <= 0) {
|
||||||
throw std::invalid_argument("N and M must be positive integers.");
|
throw std::invalid_argument("N and M must be positive integers.");
|
||||||
}
|
}
|
||||||
array result = zeros({n * m}, dtype, s);
|
array result = zeros({n, m}, dtype, s);
|
||||||
if (k >= m || -k >= n) {
|
if (k >= m || -k >= n) {
|
||||||
return reshape(result, {n, m}, s);
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);
|
int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);
|
||||||
int start_index = (k >= 0) ? k : -k * m;
|
|
||||||
|
|
||||||
array diag_indices_array = arange(
|
std::vector<array> indices;
|
||||||
start_index, start_index + diagonal_length * (m + 1), m + 1, int32, s);
|
auto s1 = std::max(0, -k);
|
||||||
array ones_array = ones({diagonal_length, 1}, dtype, s);
|
auto s2 = std::max(0, k);
|
||||||
result = scatter(result, diag_indices_array, ones_array, 0, s);
|
indices.push_back(arange(s1, diagonal_length + s1, int32, s));
|
||||||
|
indices.push_back(arange(s2, diagonal_length + s2, int32, s));
|
||||||
return reshape(result, {n, m}, s);
|
array ones_array = ones({diagonal_length, 1, 1}, dtype, s);
|
||||||
|
return scatter(result, indices, ones_array, {0, 1}, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
|
array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <iostream> // TODO
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#include "doctest/doctest.h"
|
#include "doctest/doctest.h"
|
||||||
@ -509,13 +510,14 @@ TEST_CASE("test is inf") {
|
|||||||
array x(1.0f);
|
array x(1.0f);
|
||||||
CHECK_FALSE(isinf(x).item<bool>());
|
CHECK_FALSE(isinf(x).item<bool>());
|
||||||
|
|
||||||
array y(std::numeric_limits<double>::infinity());
|
auto inf = std::numeric_limits<float>::infinity();
|
||||||
|
array y(inf);
|
||||||
CHECK(isinf(y).item<bool>());
|
CHECK(isinf(y).item<bool>());
|
||||||
|
|
||||||
array z = identity(7);
|
array z = identity(7);
|
||||||
CHECK_FALSE(any(isinf(z)).item<bool>());
|
CHECK_FALSE(any(isinf(z)).item<bool>());
|
||||||
|
|
||||||
array w = array({1.0f, std::numeric_limits<double>::infinity(), 2.0f});
|
array w = array({1.0f, inf, 2.0f});
|
||||||
CHECK(array_equal({false, true, false}, isinf(w)).item<bool>());
|
CHECK(array_equal({false, true, false}, isinf(w)).item<bool>());
|
||||||
|
|
||||||
array a(1.0f, bfloat16);
|
array a(1.0f, bfloat16);
|
||||||
@ -524,10 +526,10 @@ TEST_CASE("test is inf") {
|
|||||||
array b(1.0f, float16);
|
array b(1.0f, float16);
|
||||||
CHECK_FALSE(isinf(b).item<bool>());
|
CHECK_FALSE(isinf(b).item<bool>());
|
||||||
|
|
||||||
array c(std::numeric_limits<double>::infinity(), bfloat16);
|
array c(inf, bfloat16);
|
||||||
CHECK(isinf(c).item<bool>());
|
CHECK(isinf(c).item<bool>());
|
||||||
|
|
||||||
array d(std::numeric_limits<double>::infinity(), float16);
|
array d(inf, float16);
|
||||||
CHECK(isinf(d).item<bool>());
|
CHECK(isinf(d).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1878,6 +1880,28 @@ TEST_CASE("test scatter") {
|
|||||||
CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
|
CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test scatter types") {
|
||||||
|
for (auto t : {bool_, uint8, uint16, int8, int16}) {
|
||||||
|
auto in = zeros({4, 4}, t);
|
||||||
|
auto inds = {arange(4), arange(4)};
|
||||||
|
auto updates = ones({4, 1, 1}, t);
|
||||||
|
auto out = scatter(in, inds, updates, {0, 1});
|
||||||
|
auto expected =
|
||||||
|
array({1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}, {4, 4}, t);
|
||||||
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto t : {float16, bfloat16}) {
|
||||||
|
auto in = zeros({4, 4}, t);
|
||||||
|
auto inds = {arange(4), arange(4)};
|
||||||
|
auto updates = ones({4, 1, 1}, t);
|
||||||
|
auto out = scatter(in, inds, updates, {0, 1});
|
||||||
|
auto expected =
|
||||||
|
array({1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}, {4, 4}, t);
|
||||||
|
CHECK(allclose(out, expected).item<bool>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE("test complex ops") {
|
TEST_CASE("test complex ops") {
|
||||||
// Creation ops
|
// Creation ops
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user