diff --git a/mlx/backend/metal/kernels/arange.metal b/mlx/backend/metal/kernels/arange.metal index b896e4226..c333414f5 100644 --- a/mlx/backend/metal/kernels/arange.metal +++ b/mlx/backend/metal/kernels/arange.metal @@ -19,14 +19,14 @@ template uint index [[thread_position_in_grid]]); // clang-format off -instantiate_arange(uint8, uint8_t) +instantiate_arange(uint8, uint8_t) instantiate_arange(uint16, uint16_t) -instantiate_arange(uint32, uint32_t) +instantiate_arange(uint32, uint32_t) instantiate_arange(uint64, uint64_t) -instantiate_arange(int8, int8_t) +instantiate_arange(int8, int8_t) instantiate_arange(int16, int16_t) instantiate_arange(int32, int32_t) instantiate_arange(int64, int64_t) instantiate_arange(float16, half) instantiate_arange(float32, float) -instantiate_arange(bfloat16, bfloat16_t) // clang-format on \ No newline at end of file +instantiate_arange(bfloat16, bfloat16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/reduction/reduce_inst.h b/mlx/backend/metal/kernels/reduction/reduce_inst.h index 593db7e62..ff45290e7 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_inst.h +++ b/mlx/backend/metal/kernels/reduction/reduce_inst.h @@ -8,64 +8,67 @@ #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/reduction/ops.h" +// clang-format off #define instantiate_reduce_helper_floats(inst_f, name, op) \ inst_f(name, float16, half, op) inst_f(name, float32, float, op) \ - inst_f(name, bfloat16, bfloat16_t, op) + inst_f(name, bfloat16, bfloat16_t, op) #define instantiate_reduce_helper_uints(inst_f, name, op) \ inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \ - inst_f(name, uint32, uint32_t, op) + inst_f(name, uint32, uint32_t, op) #define instantiate_reduce_helper_ints(inst_f, name, op) \ inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \ - inst_f(name, int32, int32_t, op) + inst_f(name, int32, int32_t, op) #define instantiate_reduce_helper_64b(inst_f, name, op) \ inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op) #define instantiate_reduce_helper_types(inst_f, name, op) \ instantiate_reduce_helper_floats(inst_f, name, op) \ - instantiate_reduce_helper_uints(inst_f, name, op) \ - instantiate_reduce_helper_ints(inst_f, name, op) + instantiate_reduce_helper_uints(inst_f, name, op) \ + instantiate_reduce_helper_ints(inst_f, name, op) #define instantiate_reduce_ops(inst_f, type_f) \ type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \ - type_f(inst_f, min_, Min) type_f(inst_f, max_, Max) + type_f(inst_f, min_, Min) type_f(inst_f, max_, Max) // Special case for bool reductions #define instantiate_reduce_from_types_helper( \ inst_f, name, tname, itype, otype, op) \ - inst_f(name##tname, itype, otype, op) + inst_f(name##tname, itype, otype, op) -#define instantiate_reduce_from_types(inst_f, name, otype, op) \ - instantiate_reduce_from_types_helper(inst_f, name, bool_, bool, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, uint8, uint8_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, uint16, uint16_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, uint32, uint32_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int8, int8_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int16, int16_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int32, int32_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int64, int64_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, float16, half, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, \ - name, \ - float32, \ - float, \ - otype, \ - op) \ - instantiate_reduce_from_types_helper( \ - inst_f, \ - name, \ - bfloat16, \ - bfloat16_t, \ - otype, \ - op) \ No newline at end of file +#define instantiate_reduce_from_types(inst_f, name, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, bool_, bool, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, uint8, uint8_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, uint16, uint16_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, uint32, uint32_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int8, int8_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int16, int16_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int32, int32_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int64, int64_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, float16, half, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, \ + name, \ + float32, \ + float, \ + otype, \ + op) \ + instantiate_reduce_from_types_helper( \ + inst_f, \ + name, \ + bfloat16, \ + bfloat16_t, \ + otype, \ + op) +// clang-format on diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 6a468e771..83be15384 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2641,6 +2641,14 @@ array scatter( idx = astype(idx, dtype, s); } + // TODO, remove when scatter supports 64-bit outputs + if (to_stream(s).device == Device::gpu && size_of(a.dtype()) == 8) { + std::ostringstream msg; + msg << "[scatter] GPU scatter does not yet support " << a.dtype() + << " for the input or updates."; + throw std::invalid_argument(msg.str()); + } + inputs.insert(inputs.begin(), a); // TODO promote or cast? inputs.push_back(astype(updates, a.dtype(), s));