mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
395 lines
26 KiB
C++
395 lines
26 KiB
C++
// Copyright © 2023 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include "mlx/backend/metal/kernels/bf16.h"
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// Metal math for bfloat16
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
/*
|
|
|
|
Following the Metal Shading Language Specification (Metal 3.1)

"bfloat is an extended floating point type that only allows implicit conversion
to a type of greater floating point rank. While bfloat can be implicitly
converted to float, it cannot be implicitly converted to half, and neither
float nor half can be implicitly converted to bfloat."
|
|
|
|
Further, as far as I can tell, the stdlib math/simd functions are not defined
for bfloat and calling with an argument of type bfloat will result in that
argument getting implicitly converted to float which then returns an output
that is (likely) a float which cannot be implicitly converted into a bfloat.
|
|
|
|
This leads to situations where
bfloat a = 5.0bf;
bfloat b = metal::abs(a); // this will throw an error since abs returns float
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
|
|
|
|
For the moment, I will be adding overloaded instantiations of the math
|
|
functions to accordingly automatically handle the casting
|
|
|
|
*/
|
|
|
|
// Generates scalar math overloads that accept `itype` and return `otype`.
//
//   itype - argument type of every generated overload
//   otype - return type of every generated overload
//   ctype - type each argument is cast to before the `__metal_*` builtin is
//           called; the builtin's result is then cast to `otype`
//   mfast - value forwarded as the trailing argument of the builtins that take
//           one (fma, frexp and nextafter are called without it)
//
// Notes on the mapping:
//   * abs and fabs both forward to __metal_fabs.
//   * max/min/max3/min3/median3 forward to the same builtins as
//     fmax/fmin/fmax3/fmin3/fmedian3.
//   * fdim is composed from `select` rather than a dedicated builtin.
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
  \
  METAL_FUNC otype abs(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_fabs(xc, mfast)); \
  } \
  \
  METAL_FUNC otype acos(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_acos(xc, mfast)); \
  } \
  \
  METAL_FUNC otype acosh(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_acosh(xc, mfast)); \
  } \
  \
  METAL_FUNC otype asin(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_asin(xc, mfast)); \
  } \
  \
  METAL_FUNC otype asinh(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_asinh(xc, mfast)); \
  } \
  \
  METAL_FUNC otype atan(itype y_over_x) { \
    const ctype ratio = static_cast<ctype>(y_over_x); \
    return static_cast<otype>(__metal_atan(ratio, mfast)); \
  } \
  \
  METAL_FUNC otype atan2(itype y, itype x) { \
    const ctype yc = static_cast<ctype>(y); \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_atan2(yc, xc, mfast)); \
  } \
  \
  METAL_FUNC otype atanh(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_atanh(xc, mfast)); \
  } \
  \
  METAL_FUNC otype ceil(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_ceil(xc, mfast)); \
  } \
  \
  METAL_FUNC otype cos(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_cos(xc, mfast)); \
  } \
  \
  METAL_FUNC otype cosh(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_cosh(xc, mfast)); \
  } \
  \
  METAL_FUNC otype cospi(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_cospi(xc, mfast)); \
  } \
  \
  METAL_FUNC otype divide(itype x, itype y) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_divide(xc, yc, mfast)); \
  } \
  \
  METAL_FUNC otype exp(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_exp(xc, mfast)); \
  } \
  \
  METAL_FUNC otype exp10(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_exp10(xc, mfast)); \
  } \
  \
  METAL_FUNC otype exp2(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_exp2(xc, mfast)); \
  } \
  \
  METAL_FUNC otype fabs(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_fabs(xc, mfast)); \
  } \
  \
  METAL_FUNC otype fdim(itype x, itype y) { \
    /* The difference and the equality test are taken on the itype operands. */ \
    const ctype diff = static_cast<ctype>(x - y); \
    const bool use_zero = diff < ctype(0) || x == y; \
    return static_cast<otype>(select(diff, ctype(0), use_zero)); \
  } \
  \
  METAL_FUNC otype floor(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_floor(xc, mfast)); \
  } \
  \
  METAL_FUNC otype fma(itype x, itype y, itype z) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    const ctype zc = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fma(xc, yc, zc)); \
  } \
  \
  METAL_FUNC otype fmax(itype x, itype y) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_fmax(xc, yc, mfast)); \
  } \
  \
  METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    const ctype zc = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmax3(xc, yc, zc, mfast)); \
  } \
  \
  METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    const ctype zc = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmedian3(xc, yc, zc, mfast)); \
  } \
  \
  METAL_FUNC otype fmin(itype x, itype y) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_fmin(xc, yc, mfast)); \
  } \
  \
  METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    const ctype zc = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmin3(xc, yc, zc, mfast)); \
  } \
  \
  METAL_FUNC otype fmod(itype x, itype y) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_fmod(xc, yc, mfast)); \
  } \
  \
  METAL_FUNC otype fract(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_fract(xc, mfast)); \
  } \
  \
  METAL_FUNC otype frexp(itype x, thread int& exp) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_frexp(xc, &exp)); \
  } \
  \
  METAL_FUNC otype ldexp(itype x, int k) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_ldexp(xc, k, mfast)); \
  } \
  \
  METAL_FUNC otype log(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_log(xc, mfast)); \
  } \
  \
  METAL_FUNC otype log10(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_log10(xc, mfast)); \
  } \
  \
  METAL_FUNC otype log2(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_log2(xc, mfast)); \
  } \
  \
  METAL_FUNC otype max(itype x, itype y) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_fmax(xc, yc, mfast)); \
  } \
  \
  METAL_FUNC otype max3(itype x, itype y, itype z) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    const ctype zc = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmax3(xc, yc, zc, mfast)); \
  } \
  \
  METAL_FUNC otype median3(itype x, itype y, itype z) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    const ctype zc = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmedian3(xc, yc, zc, mfast)); \
  } \
  \
  METAL_FUNC otype min(itype x, itype y) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_fmin(xc, yc, mfast)); \
  } \
  \
  METAL_FUNC otype min3(itype x, itype y, itype z) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    const ctype zc = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmin3(xc, yc, zc, mfast)); \
  } \
  \
  METAL_FUNC otype nextafter(itype x, itype y) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_nextafter(xc, yc)); \
  } \
  \
  METAL_FUNC otype pow(itype x, itype y) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_pow(xc, yc, mfast)); \
  } \
  \
  METAL_FUNC otype powr(itype x, itype y) { \
    const ctype xc = static_cast<ctype>(x); \
    const ctype yc = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_powr(xc, yc, mfast)); \
  } \
  \
  METAL_FUNC otype rint(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_rint(xc, mfast)); \
  } \
  \
  METAL_FUNC otype round(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_round(xc, mfast)); \
  } \
  \
  METAL_FUNC otype rsqrt(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_rsqrt(xc, mfast)); \
  } \
  \
  METAL_FUNC otype sin(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_sin(xc, mfast)); \
  } \
  \
  METAL_FUNC otype sinh(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_sinh(xc, mfast)); \
  } \
  \
  METAL_FUNC otype sinpi(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_sinpi(xc, mfast)); \
  } \
  \
  METAL_FUNC otype sqrt(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_sqrt(xc, mfast)); \
  } \
  \
  METAL_FUNC otype tan(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_tan(xc, mfast)); \
  } \
  \
  METAL_FUNC otype tanh(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_tanh(xc, mfast)); \
  } \
  \
  METAL_FUNC otype tanpi(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_tanpi(xc, mfast)); \
  } \
  \
  METAL_FUNC otype trunc(itype x) { \
    const ctype xc = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_trunc(xc, mfast)); \
  }
|
|
|
|
namespace metal {
|
|
|
|
instantiate_metal_math_funcs(
|
|
bfloat16_t,
|
|
bfloat16_t,
|
|
float,
|
|
__METAL_MAYBE_FAST_MATH__);
|
|
|
|
namespace fast {
|
|
|
|
instantiate_metal_math_funcs(
|
|
bfloat16_t,
|
|
bfloat16_t,
|
|
float,
|
|
__METAL_FAST_MATH__);
|
|
|
|
} // namespace fast
|
|
|
|
namespace precise {
|
|
|
|
instantiate_metal_math_funcs(
|
|
bfloat16_t,
|
|
bfloat16_t,
|
|
float,
|
|
__METAL_PRECISE_MATH__);
|
|
|
|
} // namespace precise
|
|
|
|
} // namespace metal
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// Metal simd for bfloat16
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Generates SIMD-group data-movement overloads (broadcast and the shuffle
// family) that accept `itype` and return `otype`.
//
//   itype          - argument type of every generated overload
//   otype          - return type of every generated overload
//   ctype          - type handed to the `__metal_simd_*` builtin
//   itype_to_ctype - callable/macro converting an `itype` value to `ctype`
//   ctype_to_otype - callable/macro converting the builtin's result to `otype`
//
// The two-argument-less "fill" overloads (no `modulo`) pass
// `__metal_get_simdgroup_size(ushort())` as the modulo.
#define instantiate_metal_simd_comm_funcs( \
    itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
  \
  METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
    const ctype payload = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_broadcast(payload, broadcast_lane_id)); \
  } \
  \
  METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
    const ctype payload = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle(payload, simd_lane_id)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    const ctype payload = itype_to_ctype(data); \
    const ctype filler = itype_to_ctype(filling_data); \
    return ctype_to_otype( \
        __metal_simd_shuffle_and_fill_down(payload, filler, delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta) { \
    const ctype payload = itype_to_ctype(data); \
    const ctype filler = itype_to_ctype(filling_data); \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        payload, filler, delta, __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    const ctype payload = itype_to_ctype(data); \
    const ctype filler = itype_to_ctype(filling_data); \
    return ctype_to_otype( \
        __metal_simd_shuffle_and_fill_up(payload, filler, delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta) { \
    const ctype payload = itype_to_ctype(data); \
    const ctype filler = itype_to_ctype(filling_data); \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        payload, filler, delta, __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
    const ctype payload = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle_down(payload, delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
    const ctype payload = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle_rotate_down(payload, delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
    const ctype payload = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle_rotate_up(payload, delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
    const ctype payload = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle_up(payload, delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
    const ctype payload = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle_xor(payload, mask)); \
  }
|
|
|
|
// Generates SIMD-group reduction and prefix-scan overloads that accept
// `itype` and return `otype`.
//
//   itype - argument type of every generated overload
//   otype - return type of every generated overload
//   ctype - type the argument is cast to before the `__metal_simd_*` builtin
//           is called; the builtin's result is then cast to `otype`
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
  \
  METAL_FUNC otype simd_max(itype data) { \
    const ctype wide = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_max(wide)); \
  } \
  \
  METAL_FUNC otype simd_min(itype data) { \
    const ctype wide = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_min(wide)); \
  } \
  \
  METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
    const ctype wide = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_prefix_exclusive_product(wide)); \
  } \
  \
  METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
    const ctype wide = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_prefix_exclusive_sum(wide)); \
  } \
  \
  METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
    const ctype wide = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_prefix_inclusive_product(wide)); \
  } \
  \
  METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
    const ctype wide = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_prefix_inclusive_sum(wide)); \
  } \
  \
  METAL_FUNC otype simd_product(itype data) { \
    const ctype wide = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_product(wide)); \
  } \
  \
  METAL_FUNC otype simd_sum(itype data) { \
    const ctype wide = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_sum(wide)); \
  } \
  \
  METAL_FUNC otype simd_xor(itype data) { \
    const ctype wide = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_xor(wide)); \
  }
|
|
|
|
// Conversions between a bfloat16_t value and a uint16_t, used as the
// itype_to_ctype / ctype_to_otype hooks of the SIMD data-movement overloads.
//
// When METAL_3_0 is not defined, both directions go through as_type<>.
// When it is defined, the uint16_t is read from the value's `bits_` member and
// a bfloat16_t is rebuilt through the _MLX_BFloat16 constructor that takes the
// bits_to_bfloat() tag.
// NOTE(review): presumably METAL_3_0 is set by the build when bfloat16_t is the
// _MLX_BFloat16 struct rather than a native type - confirm against the build
// configuration and bf16.h.
#ifndef METAL_3_0

#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

#else

// The argument is parenthesized so that `.bits_` applies to the whole
// expression passed in, not just to its last operand.
#define bfloat16_to_uint16(x) (x).bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())

#endif
|
|
|
|
namespace metal {
|
|
|
|
instantiate_metal_simd_comm_funcs(
|
|
bfloat16_t,
|
|
bfloat16_t,
|
|
uint16_t,
|
|
bfloat16_to_uint16,
|
|
uint16_to_bfloat16);
|
|
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
|
|
|
|
} // namespace metal
|