mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 11:38:06 +08:00
25 lines
830 B
C
25 lines
830 B
C
// Copyright © 2023 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
// MTL_CONST expands to the `constant` qualifier when either __METAL__ or
// MLX_METAL_JIT is defined, and to nothing otherwise, so the declarations
// below can be shared by both configurations.
#if defined __METAL__ || defined MLX_METAL_JIT
#define MTL_CONST constant
#else
#define MTL_CONST
#endif

// Compile-time integer constants. Their consumers are not visible in this
// header; the names suggest tuning parameters for reduce / softmax / RMS
// kernels (NOTE(review): confirm against the kernels that include this file).
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 4;
static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
|
|
|
|
// Explicitly instantiate a templated kernel.
//
// `name` is placed inside the [[host_name(...)]] attribute, `func` is the
// kernel template, and any extra arguments are forwarded as its template
// arguments. The expansion already ends with a semicolon.
//   e.g. instantiate_kernel(binary_int, binary, a, b) ->
//   template [[host_name(binary_int)]] [[kernel]]
//       decltype(binary<a, b>) binary<a, b>;
#define instantiate_kernel(name, func, ...) \
  template [[host_name(                     \
      name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
|