#include <metal_math>
#include "bf16.h"
#include "mlx/backend/metal/kernels/bf16_math.h"
#include "mlx/backend/metal/kernels/complex.h"
#include "mlx/backend/metal/kernels/defines.h"
Go to the source code of this file.
Classes | |
struct | Limits< U > |
struct | Limits< uint8_t > |
struct | Limits< uint16_t > |
struct | Limits< uint32_t > |
struct | Limits< uint64_t > |
struct | Limits< int8_t > |
struct | Limits< int16_t > |
struct | Limits< int32_t > |
struct | Limits< int64_t > |
struct | Limits< half > |
struct | Limits< float > |
struct | Limits< bfloat16_t > |
struct | Limits< bool > |
struct | Limits< complex64_t > |
struct | LoopedElemToLoc< DIM, OffsetT, General > |
struct | LoopedElemToLoc< 1, OffsetT, true > |
struct | LoopedElemToLoc< 1, OffsetT, false > |
struct | ConditionalType< condition, T, U > |
struct | ConditionalType< true, T, U > |
Macros | |
#define | instantiate_default_limit(type) |
#define | instantiate_float_limit(type) |
#define | MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") |
Typedefs | |
typedef half | float16_t |
Functions | |
template<typename IdxT = int64_t> | |
METAL_FUNC IdxT | elem_to_loc (IdxT elem, constant const int *shape, constant const int64_t *strides, int ndim) |
template<typename IdxT = int64_t> | |
METAL_FUNC IdxT | elem_to_loc (uint3 elem, constant const int *shape, constant const int64_t *strides, int ndim) |
template<typename IdxT = int64_t> | |
METAL_FUNC IdxT | elem_to_loc_1 (uint elem, constant const int64_t &stride) |
template<typename IdxT = int64_t> | |
METAL_FUNC IdxT | elem_to_loc_2 (uint2 elem, constant const int64_t strides[2]) |
template<typename IdxT = int64_t> | |
METAL_FUNC IdxT | elem_to_loc_3 (uint3 elem, constant const int64_t strides[3]) |
template<typename IdxT = int64_t> | |
METAL_FUNC vec< IdxT, 2 > | elem_to_loc_2_nd (uint3 elem, constant const int *shape, constant const int64_t *a_strides, constant const int64_t *b_strides, int ndim) |
template<typename IdxT = int64_t> | |
METAL_FUNC vec< IdxT, 3 > | elem_to_loc_3_nd (uint3 elem, constant const int *shape, constant const int64_t *a_strides, constant const int64_t *b_strides, constant const int64_t *c_strides, int ndim) |
template<typename T, typename U> | |
T | ceildiv (T N, U M) |
Compute ceil((float)N/(float)M) | |
float | log1p (float x) |
bfloat16_t | log1p (bfloat16_t x) |
uint64_t | simd_shuffle_down (uint64_t data, uint16_t delta) |
int64_t | simd_shuffle_down (int64_t data, uint16_t delta) |
bool | simd_shuffle_down (bool data, uint16_t delta) |
complex64_t | simd_shuffle_down (complex64_t data, uint16_t delta) |
uint64_t | simd_shuffle_up (uint64_t data, uint16_t delta) |
int64_t | simd_shuffle_up (int64_t data, uint16_t delta) |
bool | simd_shuffle_up (bool data, uint16_t delta) |
complex64_t | simd_shuffle_up (complex64_t data, uint16_t delta) |
uint64_t | simd_shuffle_and_fill_up (uint64_t data, uint64_t filling, uint16_t delta) |
int64_t | simd_shuffle_and_fill_up (int64_t data, int64_t filling, uint16_t delta) |
bool | simd_shuffle_and_fill_up (bool data, bool filling, uint16_t delta) |
complex64_t | simd_shuffle_and_fill_up (complex64_t data, complex64_t filling, uint16_t delta) |
uint64_t | simd_shuffle (uint64_t data, uint16_t lane) |
int64_t | simd_shuffle (int64_t data, uint16_t lane) |
bool | simd_shuffle (bool data, uint16_t lane) |
complex64_t | simd_shuffle (complex64_t data, uint16_t lane) |
#define instantiate_default_limit | ( | type | ) |
#define instantiate_float_limit | ( | type | ) |
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") |
typedef half float16_t |
template<typename T, typename U>
T ceildiv | ( | T | N, |
U | M ) |
inline |
Compute ceil((float)N/(float)M)
METAL_FUNC IdxT elem_to_loc | ( | IdxT | elem, |
constant const int * | shape, | ||
constant const int64_t * | strides, | ||
int | ndim ) |
METAL_FUNC IdxT elem_to_loc | ( | uint3 | elem, |
constant const int * | shape, | ||
constant const int64_t * | strides, | ||
int | ndim ) |
METAL_FUNC IdxT elem_to_loc_1 | ( | uint | elem, |
constant const int64_t & | stride ) |
METAL_FUNC IdxT elem_to_loc_2 | ( | uint2 | elem, |
constant const int64_t | strides[2] ) |
METAL_FUNC vec< IdxT, 2 > elem_to_loc_2_nd | ( | uint3 | elem, |
constant const int * | shape, | ||
constant const int64_t * | a_strides, | ||
constant const int64_t * | b_strides, | ||
int | ndim ) |
METAL_FUNC IdxT elem_to_loc_3 | ( | uint3 | elem, |
constant const int64_t | strides[3] ) |
METAL_FUNC vec< IdxT, 3 > elem_to_loc_3_nd | ( | uint3 | elem, |
constant const int * | shape, | ||
constant const int64_t * | a_strides, | ||
constant const int64_t * | b_strides, | ||
constant const int64_t * | c_strides, | ||
int | ndim ) |
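float log1p | ( | float | x | ) |
inline |
bfloat16_t log1p | ( | bfloat16_t | x | ) |
inline |
uint64_t simd_shuffle | ( | uint64_t | data, |
uint16_t | lane ) |
inline |
int64_t simd_shuffle | ( | int64_t | data, |
uint16_t | lane ) |
inline |
bool simd_shuffle | ( | bool | data, |
uint16_t | lane ) |
inline |
complex64_t simd_shuffle | ( | complex64_t | data, |
uint16_t | lane ) |
inline |
uint64_t simd_shuffle_and_fill_up | ( | uint64_t | data, |
uint64_t | filling, | ||
uint16_t | delta ) |
inline |
int64_t simd_shuffle_and_fill_up | ( | int64_t | data, |
int64_t | filling, | ||
uint16_t | delta ) |
inline |
bool simd_shuffle_and_fill_up | ( | bool | data, |
bool | filling, | ||
uint16_t | delta ) |
inline |
complex64_t simd_shuffle_and_fill_up | ( | complex64_t | data, |
complex64_t | filling, | ||
uint16_t | delta ) |
inline |
uint64_t simd_shuffle_down | ( | uint64_t | data, |
uint16_t | delta ) |
inline |
int64_t simd_shuffle_down | ( | int64_t | data, |
uint16_t | delta ) |
inline |
bool simd_shuffle_down | ( | bool | data, |
uint16_t | delta ) |
inline |
complex64_t simd_shuffle_down | ( | complex64_t | data, |
uint16_t | delta ) |
inline |
uint64_t simd_shuffle_up | ( | uint64_t | data, |
uint16_t | delta ) |
inline |
int64_t simd_shuffle_up | ( | int64_t | data, |
uint16_t | delta ) |
inline |
bool simd_shuffle_up | ( | bool | data, |
uint16_t | delta ) |
inline |
complex64_t simd_shuffle_up | ( | complex64_t | data, |
uint16_t | delta ) |
inline |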