2024-05-22 12:57:13 -07:00
|
|
|
// Copyright © 2024 Apple Inc.
|
2023-11-30 11:12:53 -08:00
|
|
|
|
2024-05-22 12:57:13 -07:00
|
|
|
#include <metal_integer>
|
|
|
|
|
#include <metal_math>
|
2023-11-29 10:30:41 -08:00
|
|
|
|
2024-05-22 12:57:13 -07:00
|
|
|
// clang-format off
|
2024-06-12 14:22:12 -07:00
|
|
|
#include "mlx/backend/metal/kernels/defines.h"
|
2024-05-22 12:57:13 -07:00
|
|
|
#include "mlx/backend/metal/kernels/utils.h"
|
|
|
|
|
#include "mlx/backend/metal/kernels/binary_ops.h"
|
|
|
|
|
#include "mlx/backend/metal/kernels/binary.h"
|
2023-11-29 10:30:41 -08:00
|
|
|
|
2025-06-06 11:37:40 -07:00
|
|
|
// Instantiates the additional "n" variants of the sv/vs/vv kernels
// (svn_/vsn_/vvn_) for one (op, tname) pair. Each kernel is registered under
// the host name "<prefix><op><tname>", e.g. "vvn_Addfloat32".
//
// Unlike instantiate_binary_base, which passes an explicit trailing template
// argument of 1 to binary_sv/binary_vs/binary_vv, these calls omit it and so
// use the default declared in binary.h (presumably a larger work-per-thread
// count — confirm there).
//
// The last line deliberately has no trailing backslash: a continuation onto
// the following blank line splices whatever comes next into this macro as
// soon as that blank line is removed.
#define instantiate_binary_work_per_thread(op, tname, itype, otype)  \
  instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
  instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
  instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op)
|
|
|
|
|
|
|
|
|
|
// Instantiates the base kernel set for one (op, tname) pair. Every kernel is
// registered under the host name "<prefix><op><tname>" (e.g. "vv_Addfloat32").
//
// ss/sv/vs/vv family: binary_ss, then binary_sv/vs/vv with an explicit
// trailing template argument of 1, then the "2"-suffixed binary_sv2/vs2/vv2.
//
// g family: binary_g instantiated with (2, int) as "gn2" and with 4 as
// "gn4large"; binary_g_nd1/2/3 each instantiated twice — once with an
// explicit `int` argument and once ("large") relying on binary.h's default
// (presumably a wider index type — confirm there).
#define instantiate_binary_base(op, tname, itype, otype)                   \
  /* ss / sv / vs / vv family */                                           \
  instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op)        \
  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1)     \
  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1)     \
  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1)     \
  instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op)      \
  instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op)      \
  instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op)      \
  /* g family */                                                           \
  instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
  instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
  instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
  instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
  instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
  instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
  instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
  instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
2023-11-29 10:30:41 -08:00
|
|
|
|
2025-06-06 11:37:40 -07:00
|
|
|
// Full kernel set for one (op, tname) pair: the base kernels followed by the
// extra "n" work-per-thread kernels. Types that should receive only the base
// kernels call instantiate_binary_base directly instead of this macro.
#define instantiate_binary_all(op, tname, itype, otype) \
  instantiate_binary_base(op, tname, itype, otype)      \
  instantiate_binary_work_per_thread(op, tname, itype, otype)
|
|
|
|
|
|
|
|
|
|
// Instantiates op for every integer dtype with output type == input type.
// The 8/16/32-bit types get the full kernel set (instantiate_binary_all);
// the 64-bit types get only the base set (no "n" work-per-thread kernels).
#define instantiate_binary_integer(op)                   \
  instantiate_binary_all(op, int8, int8_t, int8_t)       \
  instantiate_binary_all(op, uint8, uint8_t, uint8_t)    \
  instantiate_binary_all(op, int16, int16_t, int16_t)    \
  instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
  instantiate_binary_all(op, int32, int32_t, int32_t)    \
  instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
  instantiate_binary_base(op, int64, int64_t, int64_t)   \
  instantiate_binary_base(op, uint64, uint64_t, uint64_t)
|
2023-11-29 10:30:41 -08:00
|
|
|
|
2024-06-12 14:22:12 -07:00
|
|
|
// Instantiates op for the floating-point dtypes (half, bfloat16, float),
// each with the full kernel set and output type == input type.
#define instantiate_binary_float(op)                           \
  instantiate_binary_all(op, float16, half, half)              \
  instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t) \
  instantiate_binary_all(op, float32, float, float)
|
2024-05-22 12:57:13 -07:00
|
|
|
|
2024-06-12 14:22:12 -07:00
|
|
|
// Instantiates op for every dtype with output type == input type: the
// floating-point set, the integer set, bool (full kernel set), and complex64
// (base kernel set only).
#define instantiate_binary_types(op)              \
  instantiate_binary_float(op)                    \
  instantiate_binary_integer(op)                  \
  instantiate_binary_all(op, bool_, bool, bool)   \
  instantiate_binary_base(op, complex64, complex64_t, complex64_t)
|
2023-11-29 10:30:41 -08:00
|
|
|
|
2024-06-12 14:22:12 -07:00
|
|
|
// Instantiates op for every input dtype with a bool output type. The 64-bit
// integers and complex64 get only the base kernel set (no "n"
// work-per-thread kernels); every other dtype gets the full set.
#define instantiate_binary_types_bool(op)                 \
  /* full kernel set */                                   \
  instantiate_binary_all(op, bool_, bool, bool)           \
  instantiate_binary_all(op, int8, int8_t, bool)          \
  instantiate_binary_all(op, uint8, uint8_t, bool)        \
  instantiate_binary_all(op, int16, int16_t, bool)        \
  instantiate_binary_all(op, uint16, uint16_t, bool)      \
  instantiate_binary_all(op, int32, int32_t, bool)        \
  instantiate_binary_all(op, uint32, uint32_t, bool)      \
  instantiate_binary_all(op, float16, half, bool)         \
  instantiate_binary_all(op, bfloat16, bfloat16_t, bool)  \
  instantiate_binary_all(op, float32, float, bool)        \
  /* base kernel set only */                              \
  instantiate_binary_base(op, int64, int64_t, bool)       \
  instantiate_binary_base(op, uint64, uint64_t, bool)     \
  instantiate_binary_base(op, complex64, complex64_t, bool)
|
2023-11-29 10:30:41 -08:00
|
|
|
|
2024-06-12 14:22:12 -07:00
|
|
|
// Ops instantiated for every dtype with output type == input type.
instantiate_binary_types(Add)
instantiate_binary_types(Subtract)
instantiate_binary_types(Multiply)
instantiate_binary_types(Divide)
instantiate_binary_types(Remainder)
instantiate_binary_types(Power)
instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum)

// Comparisons: every input dtype, bool output.
instantiate_binary_types_bool(Equal)
instantiate_binary_types_bool(NotEqual)
instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(Greater)
instantiate_binary_types_bool(GreaterEqual)

// Floating-point-only ops; LogAddExp additionally gets complex64 base kernels.
instantiate_binary_float(LogAddExp)
instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_float(ArcTan2)
|
2023-11-29 10:30:41 -08:00
|
|
|
|
|
|
|
|
// NaNEqual is only needed for floating-point and complex types, with boolean output
|
2024-06-12 14:22:12 -07:00
|
|
|
instantiate_binary_all(NaNEqual, float16, half, bool)
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
instantiate_binary_all(NaNEqual, float32, float, bool)
// complex64 gets the base kernel set only (no "n" work-per-thread kernels).
instantiate_binary_base(NaNEqual, complex64, complex64_t, bool)
|
2024-01-08 19:00:05 +04:00
|
|
|
|
2024-06-12 14:22:12 -07:00
|
|
|
// Logical ops are instantiated only for bool inputs and outputs.
instantiate_binary_all(LogicalAnd, bool_, bool, bool)
instantiate_binary_all(LogicalOr, bool_, bool, bool)
|
2024-04-26 22:03:42 -07:00
|
|
|
|
|
|
|
|
// Bitwise ops are instantiated for integer types and bool; the left/right shifts are integer-only.
|
2024-06-12 14:22:12 -07:00
|
|
|
// bool variants of the and/or/xor ops.
instantiate_binary_all(BitwiseAnd, bool_, bool, bool)
instantiate_binary_all(BitwiseOr, bool_, bool, bool)
instantiate_binary_all(BitwiseXor, bool_, bool, bool)

// Integer variants, including the integer-only shifts.
instantiate_binary_integer(BitwiseAnd)
instantiate_binary_integer(BitwiseOr)
instantiate_binary_integer(BitwiseXor)
instantiate_binary_integer(LeftShift)
instantiate_binary_integer(RightShift)
// clang-format on
|