MLX
Loading...
Searching...
No Matches
reduce_inst.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
5#include <metal_atomic>
6#include <metal_simdgroup>
7
10
// Expand `instantiate` once per floating-point dtype, in the order
// float16 (half), float32 (float), bfloat16 (bfloat16_t). Each call has the
// form instantiate(prefix, <dtype name>, <element type>, reducer).
#define instantiate_reduce_helper_floats(instantiate, prefix, reducer) \
  instantiate(prefix, float16, half, reducer)                          \
  instantiate(prefix, float32, float, reducer)                         \
  instantiate(prefix, bfloat16, bfloat16_t, reducer)
14
// Expand `instantiate` once per unsigned integer dtype up to 32 bits, in the
// order uint8, uint16, uint32. Each call has the form
// instantiate(prefix, <dtype name>, <element type>, reducer).
#define instantiate_reduce_helper_uints(instantiate, prefix, reducer) \
  instantiate(prefix, uint8, uint8_t, reducer)                        \
  instantiate(prefix, uint16, uint16_t, reducer)                      \
  instantiate(prefix, uint32, uint32_t, reducer)
18
// Expand `instantiate` once per signed integer dtype up to 32 bits, in the
// order int8, int16, int32. Each call has the form
// instantiate(prefix, <dtype name>, <element type>, reducer).
#define instantiate_reduce_helper_ints(instantiate, prefix, reducer) \
  instantiate(prefix, int8, int8_t, reducer)                         \
  instantiate(prefix, int16, int16_t, reducer)                       \
  instantiate(prefix, int32, int32_t, reducer)
22
// Expand `instantiate` for the two 64-bit integer dtypes, int64 then uint64.
// These are kept out of instantiate_reduce_helper_types so callers can route
// them to a different instantiation macro. Each call has the form
// instantiate(prefix, <dtype name>, <element type>, reducer).
#define instantiate_reduce_helper_64b(instantiate, prefix, reducer) \
  instantiate(prefix, int64, int64_t, reducer)                      \
  instantiate(prefix, uint64, uint64_t, reducer)
25
// Expand `instantiate` for every float, unsigned, and signed integer dtype up
// to 32 bits: the float list first, then the unsigned list, then the signed
// list. The 64-bit integer types are not part of this expansion; see
// instantiate_reduce_helper_64b.
#define instantiate_reduce_helper_types(instantiate, prefix, reducer) \
  instantiate_reduce_helper_floats(instantiate, prefix, reducer)      \
  instantiate_reduce_helper_uints(instantiate, prefix, reducer)       \
  instantiate_reduce_helper_ints(instantiate, prefix, reducer)
30
// Apply a per-type helper `for_types` (any macro shaped like
// for_types(inst_f, name, op), e.g. instantiate_reduce_helper_types) to each
// basic reduction, in order: sum/Sum, prod/Prod, min_/Min, max_/Max.
// `instantiate` is forwarded unchanged as the helper's first argument.
#define instantiate_reduce_ops(instantiate, for_types) \
  for_types(instantiate, sum, Sum)                     \
  for_types(instantiate, prod, Prod)                   \
  for_types(instantiate, min_, Min)                    \
  for_types(instantiate, max_, Max)
34
// Special case for bool reductions: every input dtype listed below is paired
// with one caller-supplied output type `otype`, so a single reduction op can
// be instantiated across all input types.

// Forward one (input dtype, output type) pair to `inst_f`. The kernel name is
// built by token-pasting `name` and `tname` with no separator, i.e.
// `<name><tname>`; `itype` is the input element type.
#define instantiate_reduce_from_types_helper( \
    inst_f, name, tname, itype, otype, op)    \
  inst_f(name##tname, itype, otype, op)

// Instantiate `inst_f` once per supported input dtype (bool, unsigned and
// signed integers, floats), always with output type `otype`.
//
// uint64 is listed alongside int64 so the 64-bit integer coverage matches
// instantiate_reduce_helper_64b, which lists both.
// NOTE(review): confirm the reduction ops expanded through this macro accept
// uint64_t inputs the same way they accept int64_t.
#define instantiate_reduce_from_types(inst_f, name, otype, op)                \
  instantiate_reduce_from_types_helper(inst_f, name, bool_, bool, otype, op)  \
  instantiate_reduce_from_types_helper(                                       \
      inst_f, name, uint8, uint8_t, otype, op)                                \
  instantiate_reduce_from_types_helper(                                       \
      inst_f, name, uint16, uint16_t, otype, op)                              \
  instantiate_reduce_from_types_helper(                                       \
      inst_f, name, uint32, uint32_t, otype, op)                              \
  instantiate_reduce_from_types_helper(                                       \
      inst_f, name, uint64, uint64_t, otype, op)                              \
  instantiate_reduce_from_types_helper(                                       \
      inst_f, name, int8, int8_t, otype, op)                                  \
  instantiate_reduce_from_types_helper(                                       \
      inst_f, name, int16, int16_t, otype, op)                                \
  instantiate_reduce_from_types_helper(                                       \
      inst_f, name, int32, int32_t, otype, op)                                \
  instantiate_reduce_from_types_helper(                                       \
      inst_f, name, int64, int64_t, otype, op)                                \
  instantiate_reduce_from_types_helper(                                       \
      inst_f, name, float16, half, otype, op)                                 \
  instantiate_reduce_from_types_helper(                                       \
      inst_f, name, float32, float, otype, op)                                \
  instantiate_reduce_from_types_helper(                                       \
      inst_f, name, bfloat16, bfloat16_t, otype, op)