MLX
Loading...
Searching...
No Matches
reduce_inst.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
5#include <metal_atomic>
6#include <metal_simdgroup>
7
10
// clang-format off

// Expands the caller-supplied macro `inst_f(name, tname, type, op)` once per
// floating-point element type. `tname` is the identifier-safe type name and
// `type` is the corresponding Metal type:
//   float16 -> half, float32 -> float, bfloat16 -> bfloat16_t.
// `name` and `op` are forwarded unchanged.
#define instantiate_reduce_helper_floats(inst_f, name, op) \
  inst_f(name, float16, half, op)                          \
  inst_f(name, float32, float, op)                         \
  inst_f(name, bfloat16, bfloat16_t, op)

// Expands `inst_f(name, tname, type, op)` once per unsigned integer type of
// 8, 16 and 32 bits: uint8 -> uint8_t, uint16 -> uint16_t, uint32 -> uint32_t.
#define instantiate_reduce_helper_uints(inst_f, name, op) \
  inst_f(name, uint8, uint8_t, op)                         \
  inst_f(name, uint16, uint16_t, op)                       \
  inst_f(name, uint32, uint32_t, op)

// Expands `inst_f(name, tname, type, op)` once per signed integer type of
// 8, 16 and 32 bits: int8 -> int8_t, int16 -> int16_t, int32 -> int32_t.
#define instantiate_reduce_helper_ints(inst_f, name, op) \
  inst_f(name, int8, int8_t, op)                          \
  inst_f(name, int16, int16_t, op)                        \
  inst_f(name, int32, int32_t, op)

// Expands `inst_f(name, tname, type, op)` for the 64-bit integer types:
// int64 -> int64_t and uint64 -> uint64_t. These are kept out of
// instantiate_reduce_helper_types, so a caller that wants them must expand
// this macro explicitly.
#define instantiate_reduce_helper_64b(inst_f, name, op) \
  inst_f(name, int64, int64_t, op)                       \
  inst_f(name, uint64, uint64_t, op)

// Expands `inst_f` for every float type and every 8/16/32-bit unsigned and
// signed integer type, in that order, by chaining the float, uint and int
// helpers. 64-bit integers are NOT included; they are listed in
// instantiate_reduce_helper_64b, which must be expanded separately if needed.
#define instantiate_reduce_helper_types(inst_f, name, op) \
  instantiate_reduce_helper_floats(inst_f, name, op)      \
  instantiate_reduce_helper_uints(inst_f, name, op)       \
  instantiate_reduce_helper_ints(inst_f, name, op)

// Expands `type_f(inst_f, name, Op)` once per reduction, in this order:
//   sum -> Sum, prod -> Prod, min_ -> Min, max_ -> Max.
// `type_f` is presumably one of the instantiate_reduce_helper_* macros (or a
// macro with the same (inst_f, name, op) signature), which then expands
// `inst_f` per element type. The trailing underscore in `min_`/`max_`
// presumably avoids clashing with `min`/`max` identifiers — TODO confirm.
#define instantiate_reduce_ops(inst_f, type_f) \
  type_f(inst_f, sum, Sum)                      \
  type_f(inst_f, prod, Prod)                    \
  type_f(inst_f, min_, Min)                     \
  type_f(inst_f, max_, Max)

// Special case for bool reductions: the output type `otype` is fixed by the
// caller, and one instantiation is produced per supported INPUT type.

// Expands `inst_f(name##tname, itype, otype, op)`. The first argument is the
// token formed by pasting `name` and the input type name `tname` (e.g. `name`
// = foo, `tname` = int8 gives `fooint8`).
#define instantiate_reduce_from_types_helper(inst_f, name, tname, itype, otype, op) \
  inst_f(name##tname, itype, otype, op)

// Expands instantiate_reduce_from_types_helper for each input type below, all
// reducing into the same `otype` with the same `op`.
// NOTE(review): int64 is listed but uint64 is not, whereas
// instantiate_reduce_helper_64b covers both — confirm the omission is intended.
#define instantiate_reduce_from_types(inst_f, name, otype, op)                     \
  instantiate_reduce_from_types_helper(inst_f, name, bool_, bool, otype, op)         \
  instantiate_reduce_from_types_helper(inst_f, name, uint8, uint8_t, otype, op)      \
  instantiate_reduce_from_types_helper(inst_f, name, uint16, uint16_t, otype, op)    \
  instantiate_reduce_from_types_helper(inst_f, name, uint32, uint32_t, otype, op)    \
  instantiate_reduce_from_types_helper(inst_f, name, int8, int8_t, otype, op)        \
  instantiate_reduce_from_types_helper(inst_f, name, int16, int16_t, otype, op)      \
  instantiate_reduce_from_types_helper(inst_f, name, int32, int32_t, otype, op)      \
  instantiate_reduce_from_types_helper(inst_f, name, int64, int64_t, otype, op)      \
  instantiate_reduce_from_types_helper(inst_f, name, float16, half, otype, op)       \
  instantiate_reduce_from_types_helper(inst_f, name, float32, float, otype, op)      \
  instantiate_reduce_from_types_helper(inst_f, name, bfloat16, bfloat16_t, otype, op)
// clang-format on