mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-03 15:51:15 +08:00
Allow scatter type exception to be caught by checking in op (#1077)
* allow exception to be caught in main thread * only for gpu * more detailed scatter error
This commit is contained in:
parent
7178ac0111
commit
863039da4c
@ -8,6 +8,7 @@
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_reduce_helper_floats(inst_f, name, op) \
|
||||
inst_f(name, float16, half, op) inst_f(name, float32, float, op) \
|
||||
inst_f(name, bfloat16, bfloat16_t, op)
|
||||
@ -38,7 +39,8 @@
|
||||
inst_f(name##tname, itype, otype, op)
|
||||
|
||||
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
|
||||
instantiate_reduce_from_types_helper(inst_f, name, bool_, bool, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, bool_, bool, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint8, uint8_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
@ -69,3 +71,4 @@
|
||||
bfloat16_t, \
|
||||
otype, \
|
||||
op)
|
||||
// clang-format on
|
||||
|
@ -2641,6 +2641,14 @@ array scatter(
|
||||
idx = astype(idx, dtype, s);
|
||||
}
|
||||
|
||||
// TODO, remove when scatter supports 64-bit outputs
|
||||
if (to_stream(s).device == Device::gpu && size_of(a.dtype()) == 8) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scatter] GPU scatter does not yet support " << a.dtype()
|
||||
<< " for the input or updates.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
inputs.insert(inputs.begin(), a);
|
||||
// TODO promote or cast?
|
||||
inputs.push_back(astype(updates, a.dtype(), s));
|
||||
|
Loading…
Reference in New Issue
Block a user