2024-05-23 03:57:13 +08:00
|
|
|
// Copyright © 2024 Apple Inc.
|
2024-02-23 07:10:48 +08:00
|
|
|
|
|
|
|
#include <metal_integer>
|
|
|
|
#include <metal_math>
|
|
|
|
|
2024-05-23 03:57:13 +08:00
|
|
|
// clang-format off
|
2024-04-30 22:18:09 +08:00
|
|
|
#include "mlx/backend/metal/kernels/utils.h"
|
2024-05-23 03:57:13 +08:00
|
|
|
#include "mlx/backend/metal/kernels/ternary_ops.h"
|
|
|
|
#include "mlx/backend/metal/kernels/ternary.h"
|
2024-02-23 07:10:48 +08:00
|
|
|
|
2024-12-13 00:59:45 +08:00
|
|
|
// Instantiates every ternary kernel variant for one (op, dtype) pair.
//
// Each variant is handed to instantiate_kernel under the name
// <prefix><op><tname>, built by string-literal concatenation, e.g.
// "v_Selectfloat32":
//   v_, v2_          -> ternary_v / ternary_v2
//   gn2_             -> ternary_g with extra template args (2, int)
//   g1_, g2_, g3_    -> ternary_g_nd1 / nd2 / nd3 with an explicit `int` arg
//   g1large_ .. g3large_ -> the same nd kernels without the `int` arg,
//                       presumably falling back to a wider default index
//                       type declared in ternary.h (confirm there)
//   gn4large_        -> ternary_g with extra template arg 4 and no `int`
// NOTE(review): the meaning of the 2 / 4 arguments to ternary_g is defined
// in ternary.h (presumably work per thread) -- confirm before relying on it.
//
// The final line intentionally carries no line continuation: a dangling
// continuation would splice whatever line follows the macro into its body.
#define instantiate_ternary_all(op, tname, type)                         \
  instantiate_kernel("v_" #op #tname, ternary_v, type, op)               \
  instantiate_kernel("v2_" #op #tname, ternary_v2, type, op)             \
  instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int)     \
  instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int)     \
  instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int)     \
  instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int)     \
  instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op)     \
  instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op)     \
  instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op)     \
  instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4)
|
2024-06-13 05:22:12 +08:00
|
|
|
|
|
|
|
// Per-family fan-out helpers used by instantiate_ternary_types. Each one
// expands instantiate_ternary_all(op, <tag>, <type>) once per dtype, where
// <tag> becomes the dtype suffix of the generated kernel names.
#define instantiate_ternary_unsigned_ints(op)       \
  instantiate_ternary_all(op, uint8, uint8_t)       \
  instantiate_ternary_all(op, uint16, uint16_t)     \
  instantiate_ternary_all(op, uint32, uint32_t)     \
  instantiate_ternary_all(op, uint64, uint64_t)

#define instantiate_ternary_signed_ints(op)         \
  instantiate_ternary_all(op, int8, int8_t)         \
  instantiate_ternary_all(op, int16, int16_t)       \
  instantiate_ternary_all(op, int32, int32_t)       \
  instantiate_ternary_all(op, int64, int64_t)

#define instantiate_ternary_floats(op)              \
  instantiate_ternary_all(op, float16, half)        \
  instantiate_ternary_all(op, float32, float)       \
  instantiate_ternary_all(op, bfloat16, bfloat16_t)

// Instantiates every ternary kernel variant for `op` over all 13 dtypes
// listed here, in the order: bool, unsigned ints, signed ints, floats,
// complex64.
#define instantiate_ternary_types(op)               \
  instantiate_ternary_all(op, bool_, bool)          \
  instantiate_ternary_unsigned_ints(op)             \
  instantiate_ternary_signed_ints(op)               \
  instantiate_ternary_floats(op)                    \
  instantiate_ternary_all(op, complex64, complex64_t)
// clang-format on
|
|
|
|
|
|
|
|
// Emit every ternary kernel variant for the `Select` op across all dtypes
// enumerated by instantiate_ternary_types. `Select` is not defined in this
// file -- presumably it is declared in ternary_ops.h; confirm there.
instantiate_ternary_types(Select)
|