MLX
Loading...
Searching...
No Matches
defines.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
// Qualifier placed in front of the shared program-scope constants of this
// header. Unless the Metal compiler (__METAL__) is in use or MLX_METAL_JIT is
// defined, it expands to nothing, so the declarations compile as ordinary
// C++ constants on the host. Otherwise it expands to Metal's `constant`
// address-space qualifier.
#if !defined(__METAL__) && !defined(MLX_METAL_JIT)
#define MTL_CONST
#else
#define MTL_CONST constant
#endif
10
11static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
12static MTL_CONST constexpr int REDUCE_N_READS = 16;
13static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
14static MTL_CONST constexpr int RMS_N_READS = 4;
15static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
16
// Explicitly instantiate a templated kernel under a given host-visible name.
//
//   name  - passed verbatim to [[host_name(...)]]. Metal expects a string
//           literal here, so the argument must expand to one (adjacent
//           literals are concatenated before the attribute is parsed).
//   func  - the kernel function template.
//   ...   - template arguments forwarded to func<...>.
//
// The declared type is taken from decltype(func<...>), so the kernel's
// signature is never repeated. For example:
//
//   instantiate_kernel("binary_int", binary, a, b)
//
// expands to
//
//   template [[host_name("binary_int")]] [[kernel]]
//   decltype(binary<a, b>) binary<a, b>;
#define instantiate_kernel(name, func, ...) \
  template [[host_name(                     \
      name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
static constexpr int MAX_REDUCE_SPECIALIZED_DIMS
Definition defines.h:11
static constexpr int REDUCE_N_READS
Definition defines.h:12
static constexpr int RMS_LOOPED_LIMIT
Definition defines.h:15
static constexpr int SOFTMAX_N_READS
Definition defines.h:13
#define MTL_CONST
Definition defines.h:8
static constexpr int RMS_N_READS
Definition defines.h:14