mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-05 16:51:13 +08:00
Allow scatter type exception to be caught by checking in op (#1077)
* Allow exception to be caught in main thread
* Only for GPU
* More detailed scatter error
This commit is contained in:
parent
7178ac0111
commit
863039da4c
@ -19,14 +19,14 @@ template <typename T>
|
|||||||
uint index [[thread_position_in_grid]]);
|
uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
instantiate_arange(uint8, uint8_t)
|
instantiate_arange(uint8, uint8_t)
|
||||||
instantiate_arange(uint16, uint16_t)
|
instantiate_arange(uint16, uint16_t)
|
||||||
instantiate_arange(uint32, uint32_t)
|
instantiate_arange(uint32, uint32_t)
|
||||||
instantiate_arange(uint64, uint64_t)
|
instantiate_arange(uint64, uint64_t)
|
||||||
instantiate_arange(int8, int8_t)
|
instantiate_arange(int8, int8_t)
|
||||||
instantiate_arange(int16, int16_t)
|
instantiate_arange(int16, int16_t)
|
||||||
instantiate_arange(int32, int32_t)
|
instantiate_arange(int32, int32_t)
|
||||||
instantiate_arange(int64, int64_t)
|
instantiate_arange(int64, int64_t)
|
||||||
instantiate_arange(float16, half)
|
instantiate_arange(float16, half)
|
||||||
instantiate_arange(float32, float)
|
instantiate_arange(float32, float)
|
||||||
instantiate_arange(bfloat16, bfloat16_t) // clang-format on
|
instantiate_arange(bfloat16, bfloat16_t) // clang-format on
|
||||||
|
@ -8,64 +8,67 @@
|
|||||||
#include "mlx/backend/metal/kernels/defines.h"
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_reduce_helper_floats(inst_f, name, op) \
|
#define instantiate_reduce_helper_floats(inst_f, name, op) \
|
||||||
inst_f(name, float16, half, op) inst_f(name, float32, float, op) \
|
inst_f(name, float16, half, op) inst_f(name, float32, float, op) \
|
||||||
inst_f(name, bfloat16, bfloat16_t, op)
|
inst_f(name, bfloat16, bfloat16_t, op)
|
||||||
|
|
||||||
#define instantiate_reduce_helper_uints(inst_f, name, op) \
|
#define instantiate_reduce_helper_uints(inst_f, name, op) \
|
||||||
inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \
|
inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \
|
||||||
inst_f(name, uint32, uint32_t, op)
|
inst_f(name, uint32, uint32_t, op)
|
||||||
|
|
||||||
#define instantiate_reduce_helper_ints(inst_f, name, op) \
|
#define instantiate_reduce_helper_ints(inst_f, name, op) \
|
||||||
inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \
|
inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \
|
||||||
inst_f(name, int32, int32_t, op)
|
inst_f(name, int32, int32_t, op)
|
||||||
|
|
||||||
#define instantiate_reduce_helper_64b(inst_f, name, op) \
|
#define instantiate_reduce_helper_64b(inst_f, name, op) \
|
||||||
inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op)
|
inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op)
|
||||||
|
|
||||||
#define instantiate_reduce_helper_types(inst_f, name, op) \
|
#define instantiate_reduce_helper_types(inst_f, name, op) \
|
||||||
instantiate_reduce_helper_floats(inst_f, name, op) \
|
instantiate_reduce_helper_floats(inst_f, name, op) \
|
||||||
instantiate_reduce_helper_uints(inst_f, name, op) \
|
instantiate_reduce_helper_uints(inst_f, name, op) \
|
||||||
instantiate_reduce_helper_ints(inst_f, name, op)
|
instantiate_reduce_helper_ints(inst_f, name, op)
|
||||||
|
|
||||||
#define instantiate_reduce_ops(inst_f, type_f) \
|
#define instantiate_reduce_ops(inst_f, type_f) \
|
||||||
type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \
|
type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \
|
||||||
type_f(inst_f, min_, Min) type_f(inst_f, max_, Max)
|
type_f(inst_f, min_, Min) type_f(inst_f, max_, Max)
|
||||||
|
|
||||||
// Special case for bool reductions
|
// Special case for bool reductions
|
||||||
#define instantiate_reduce_from_types_helper( \
|
#define instantiate_reduce_from_types_helper( \
|
||||||
inst_f, name, tname, itype, otype, op) \
|
inst_f, name, tname, itype, otype, op) \
|
||||||
inst_f(name##tname, itype, otype, op)
|
inst_f(name##tname, itype, otype, op)
|
||||||
|
|
||||||
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
|
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
|
||||||
instantiate_reduce_from_types_helper(inst_f, name, bool_, bool, otype, op) \
|
instantiate_reduce_from_types_helper( \
|
||||||
instantiate_reduce_from_types_helper( \
|
inst_f, name, bool_, bool, otype, op) \
|
||||||
inst_f, name, uint8, uint8_t, otype, op) \
|
instantiate_reduce_from_types_helper( \
|
||||||
instantiate_reduce_from_types_helper( \
|
inst_f, name, uint8, uint8_t, otype, op) \
|
||||||
inst_f, name, uint16, uint16_t, otype, op) \
|
instantiate_reduce_from_types_helper( \
|
||||||
instantiate_reduce_from_types_helper( \
|
inst_f, name, uint16, uint16_t, otype, op) \
|
||||||
inst_f, name, uint32, uint32_t, otype, op) \
|
instantiate_reduce_from_types_helper( \
|
||||||
instantiate_reduce_from_types_helper( \
|
inst_f, name, uint32, uint32_t, otype, op) \
|
||||||
inst_f, name, int8, int8_t, otype, op) \
|
instantiate_reduce_from_types_helper( \
|
||||||
instantiate_reduce_from_types_helper( \
|
inst_f, name, int8, int8_t, otype, op) \
|
||||||
inst_f, name, int16, int16_t, otype, op) \
|
instantiate_reduce_from_types_helper( \
|
||||||
instantiate_reduce_from_types_helper( \
|
inst_f, name, int16, int16_t, otype, op) \
|
||||||
inst_f, name, int32, int32_t, otype, op) \
|
instantiate_reduce_from_types_helper( \
|
||||||
instantiate_reduce_from_types_helper( \
|
inst_f, name, int32, int32_t, otype, op) \
|
||||||
inst_f, name, int64, int64_t, otype, op) \
|
instantiate_reduce_from_types_helper( \
|
||||||
instantiate_reduce_from_types_helper( \
|
inst_f, name, int64, int64_t, otype, op) \
|
||||||
inst_f, name, float16, half, otype, op) \
|
instantiate_reduce_from_types_helper( \
|
||||||
instantiate_reduce_from_types_helper( \
|
inst_f, name, float16, half, otype, op) \
|
||||||
inst_f, \
|
instantiate_reduce_from_types_helper( \
|
||||||
name, \
|
inst_f, \
|
||||||
float32, \
|
name, \
|
||||||
float, \
|
float32, \
|
||||||
otype, \
|
float, \
|
||||||
op) \
|
otype, \
|
||||||
instantiate_reduce_from_types_helper( \
|
op) \
|
||||||
inst_f, \
|
instantiate_reduce_from_types_helper( \
|
||||||
name, \
|
inst_f, \
|
||||||
bfloat16, \
|
name, \
|
||||||
bfloat16_t, \
|
bfloat16, \
|
||||||
otype, \
|
bfloat16_t, \
|
||||||
op)
|
otype, \
|
||||||
|
op)
|
||||||
|
// clang-format on
|
||||||
|
@ -2641,6 +2641,14 @@ array scatter(
|
|||||||
idx = astype(idx, dtype, s);
|
idx = astype(idx, dtype, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO, remove when scatter supports 64-bit outputs
|
||||||
|
if (to_stream(s).device == Device::gpu && size_of(a.dtype()) == 8) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[scatter] GPU scatter does not yet support " << a.dtype()
|
||||||
|
<< " for the input or updates.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
inputs.insert(inputs.begin(), a);
|
inputs.insert(inputs.begin(), a);
|
||||||
// TODO promote or cast?
|
// TODO promote or cast?
|
||||||
inputs.push_back(astype(updates, a.dtype(), s));
|
inputs.push_back(astype(updates, a.dtype(), s));
|
||||||
|
Loading…
Reference in New Issue
Block a user