mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
76 lines
4.3 KiB
Metal
76 lines
4.3 KiB
Metal
// Copyright © 2024 Apple Inc.
|
|
|
|
// clang-format off
|
|
#include "mlx/backend/metal/kernels/utils.h"
|
|
#include "mlx/backend/metal/kernels/copy.h"
|
|
|
|
// Instantiates the "sn" and "vn" copy kernels for one (itype, otype) pair.
//
// tname is stringified and appended to each prefix, so the two kernels get
// the names "sn_copy<tname>" and "vn_copy<tname>". Unlike the "s_copy" and
// "v_copy" entries in instantiate_copy_base, no trailing template argument is
// passed to copy_s / copy_v here.
// NOTE(review): instantiate_kernel is not defined in this file (presumably it
// comes from utils.h), and the omitted argument presumably falls back to a
// default declared in copy.h -- confirm.
#define instantiate_copy_work_per_thread(tname, itype, otype) \
  instantiate_kernel("sn_copy" #tname, copy_s, itype, otype)  \
  instantiate_kernel("vn_copy" #tname, copy_v, itype, otype)
|
|
|
|
// Instantiates the baseline set of copy kernels for one (itype, otype) pair.
//
// tname is stringified and appended to each prefix to form the kernel name
// (for example "s_copy<tname>"). Twelve kernels are emitted:
//   - s / v     : copy_s and copy_v with a trailing template argument of 1
//   - s2 / v2   : copy_s2 and copy_v2
//   - g1..g3    : copy_g_nd1..copy_g_nd3 with a trailing `int` argument
//   - gn2       : copy_g with the arguments (2, int)
//   - g*large   : the same copy_g_nd* templates without the trailing `int`
//   - gn4large  : copy_g with the argument 4
// NOTE(review): the meaning of the trailing arguments (1, 2, 4, int) is set by
// the templates in copy.h, which is not visible here -- presumably a
// per-thread work count and an index type; confirm.
#define instantiate_copy_base(tname, itype, otype)                     \
  instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1)         \
  instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1)         \
  instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype)          \
  instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype)          \
  instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int)  \
  instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int)  \
  instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int)  \
  instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int)  \
  instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype)  \
  instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype)  \
  instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype)  \
  instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
|
|
|
|
// Instantiates every copy kernel variant for one (itype, otype) pair: the
// baseline set from instantiate_copy_base followed by the "sn"/"vn" kernels
// from instantiate_copy_work_per_thread. All three arguments are forwarded
// unchanged to both macros.
#define instantiate_copy_all(tname, itype, otype) \
  instantiate_copy_base(tname, itype, otype)      \
  instantiate_copy_work_per_thread(tname, itype, otype)
|
|
|
|
// Instantiates the "gg" copy kernels, which are only generated with the same
// element type on both sides (type is passed as both template type
// arguments).
//
// tname is stringified and appended to each prefix to form the kernel name.
// Sixteen kernels are emitted in two parallel groups of eight:
//   - gg1..gg3, ggn2, gg1large..gg3large, ggn4large, built from
//     copy_gg_nd1..copy_gg_nd3 and copy_gg
//   - the same eight names with a "_dynamic" infix, built from
//     copy_gg_dynamic_nd1..copy_gg_dynamic_nd3 and copy_gg_dynamic
// Within each group the non-"large" entries pass a trailing `int` argument
// and the "large" entries omit it; ggn2 passes 2 and ggn4large passes 4.
// NOTE(review): what the trailing arguments select is defined by the
// templates in copy.h, which is not visible here -- confirm there.
#define instantiate_copy_same(tname, type)                                             \
  instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int)                  \
  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int)                  \
  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int)                  \
  instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int)                  \
  instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type)                  \
  instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type)                  \
  instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type)                  \
  instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4)                  \
  instantiate_kernel("gg1_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type, int)  \
  instantiate_kernel("gg2_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type, int)  \
  instantiate_kernel("gg3_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type, int)  \
  instantiate_kernel("ggn2_dynamic_copy" #tname, copy_gg_dynamic, type, type, 2, int)  \
  instantiate_kernel("gg1large_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type)  \
  instantiate_kernel("gg2large_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type)  \
  instantiate_kernel("gg3large_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type)  \
  instantiate_kernel("ggn4large_dynamic_copy" #tname, copy_gg_dynamic, type, type, 4)
|
|
|
|
// Instantiates every copy kernel whose source element type is itype.
//
// itname is token-pasted in front of each destination type name to build the
// tname suffix, so itname = float16 yields float16bool_, float16uint8, ...
// The expansion consists of:
//   - instantiate_copy_same with the suffix <itname><itname> and itype on
//     both sides;
//   - one conversion set per destination type: bool_, uint8, uint16, uint32,
//     uint64, int8, int16, int32, int64, float16, float32, bfloat16 and
//     complex64.
// Destinations uint64, int64 and complex64 go through instantiate_copy_base
// only, so they get no "sn"/"vn" kernels; every other destination uses
// instantiate_copy_all.
// NOTE(review): the reason those three destinations are base-only is not
// stated in this file -- confirm against copy.h and the host-side dispatch.
#define instantiate_copy_itype(itname, itype)                    \
  instantiate_copy_same(itname ##itname, itype)                  \
  instantiate_copy_all(itname ##bool_, itype, bool)              \
  instantiate_copy_all(itname ##uint8, itype, uint8_t)           \
  instantiate_copy_all(itname ##uint16, itype, uint16_t)         \
  instantiate_copy_all(itname ##uint32, itype, uint32_t)         \
  instantiate_copy_base(itname ##uint64, itype, uint64_t)        \
  instantiate_copy_all(itname ##int8, itype, int8_t)             \
  instantiate_copy_all(itname ##int16, itype, int16_t)           \
  instantiate_copy_all(itname ##int32, itype, int32_t)           \
  instantiate_copy_base(itname ##int64, itype, int64_t)          \
  instantiate_copy_all(itname ##float16, itype, half)            \
  instantiate_copy_all(itname ##float32, itype, float)           \
  instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t)     \
  instantiate_copy_base(itname ##complex64, itype, complex64_t)
|
|
|
|
// Emit the copy kernels for every source element type. Each invocation
// expands to the same-type "gg" kernels plus one conversion set per
// destination type listed in instantiate_copy_itype; the thirteen source
// types below are the same thirteen names used there as destinations.
instantiate_copy_itype(bool_, bool)
instantiate_copy_itype(uint8, uint8_t)
instantiate_copy_itype(uint16, uint16_t)
instantiate_copy_itype(uint32, uint32_t)
instantiate_copy_itype(uint64, uint64_t)
instantiate_copy_itype(int8, int8_t)
instantiate_copy_itype(int16, int16_t)
instantiate_copy_itype(int32, int32_t)
instantiate_copy_itype(int64, int64_t)
instantiate_copy_itype(float16, half)
instantiate_copy_itype(float32, float)
instantiate_copy_itype(bfloat16, bfloat16_t)
instantiate_copy_itype(complex64, complex64_t) // clang-format on
|