mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Fix concatenate/slice_update vjp + reduce binary size (#1735)
* fix concatenate vjp + reduce binary size * also cast in slice update
This commit is contained in:
@@ -12,21 +12,24 @@
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
|
||||
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
|
||||
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
|
||||
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
|
||||
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype)
|
||||
|
||||
#define instantiate_copy_same(tname, type) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int) \
|
||||
instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int) \
|
||||
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \
|
||||
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \
|
||||
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \
|
||||
instantiate_kernel("gn4large_copy" #tname, copy_g, type, type, 4) \
|
||||
instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4)
|
||||
|
||||
#define instantiate_copy_itype(itname, itype) \
|
||||
instantiate_copy_same(itname ##itname, itype) \
|
||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
|
||||
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
|
||||
|
Reference in New Issue
Block a user