mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 15:11:14 +08:00
318 lines
12 KiB
C++
318 lines
12 KiB
C++
// Copyright © 2023 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <metal_stdlib>
|
|
|
|
using namespace metal;
|
|
|
|
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)

// Metal 3.1 and later provide a built-in `bfloat` type, so alias it directly.
// Older toolchains fall through to the software emulation in the #else branch.
using bfloat16_t = bfloat;

#else
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// Helpers
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Converts a float32 to a bfloat16 bit pattern (the high 16 bits of the
// float), rounding to nearest with ties to even.
// Any NaN input maps to the canonical quiet NaN 0x7FC0 (sign and payload are
// discarded), so a NaN whose payload sits only in the low 16 bits can never
// truncate into an infinity.
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
  uint32_t raw = as_type<uint32_t>(x);

  // NaN <=> magnitude bits strictly above the infinity encoding.
  const uint32_t magnitude = raw & ~_fp_encoding_traits<float>::sign_mask;
  if (magnitude > _fp_encoding_traits<float>::inf_mask) {
    return uint16_t(0x7FC0);
  }

  // Round to nearest even: adding 0x7FFF rounds up anything past the halfway
  // point, and adding the LSB of the kept half breaks exact ties toward an
  // even result. A carry into the exponent correctly overflows to infinity.
  const uint32_t kept_lsb = (raw >> 16) & 1u;
  raw += kept_lsb + 0x7FFFu;

  // Keep only the upper half.
  return uint16_t(raw >> 16);
}
|
|
|
|
// Expands a bfloat16 bit pattern to the float32 it represents. This is exact:
// the 16 bits become the high half of the float and the low half is zero.
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
  const uint32_t widened = static_cast<uint32_t>(x);
  return as_type<float>(widened << 16);
}
|
|
|
|
// Forward declaration so the conversion traits below can refer to (and
// exclude) the emulated type before its definition.
struct _MLX_BFloat16;

// True when a T can be turned into a bfloat16 by going through float.
// _MLX_BFloat16 itself is excluded. This trait gates the templated converting
// constructors in _MLX_BFloat16, so bfloat16 -> bfloat16 never takes the
// float round-trip path.
template <typename T>
static constexpr constant bool can_convert_to_bfloat =
    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;

// True when a bfloat16 can be turned into a T by going through float.
// _MLX_BFloat16 is excluded for the same reason as above. This trait gates
// the templated conversion operators.
template <typename T>
static constexpr constant bool can_convert_from_bfloat =
    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// Bfloat struct
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Software-emulated bfloat16 for Metal versions without a native `bfloat`.
// The value is stored as the upper 16 bits of the equivalent IEEE-754 float32.
// Every conversion goes through float: float_to_bfloat_bits on the way in,
// bfloat_bits_to_float on the way out.
struct _MLX_BFloat16 {
  /////////////////////////////////////////////////////////////////////////////
  // Constructors

  // Raw bfloat16 bit pattern (high half of the corresponding float32).
  uint16_t bits_;

  // Defaulted default constructors, one per Metal address space. Metal
  // qualifies member functions by the address space of the object.
  // NOTE: being defaulted, these leave bits_ uninitialized.
  _MLX_BFloat16() thread = default;
  _MLX_BFloat16() threadgroup = default;
  _MLX_BFloat16() device = default;
  _MLX_BFloat16() constant = default;

  // Tag type that selects the raw-bits constructor below. Without the tag, a
  // uint16_t argument would go to the templated value-converting constructor,
  // which treats it as a number rather than as a bit pattern.
  struct bits_to_bfloat_struct {};
  static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
    return bits_to_bfloat_struct();
  }
  // Builds a bfloat16 directly from its bit pattern (no rounding, and NaN
  // payloads are preserved), e.g. _MLX_BFloat16(0x7F80, bits_to_bfloat()).
  constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
      : bits_(bits) {}

  /////////////////////////////////////////////////////////////////////////////
  // Conversions to bfloat

  // Implicit value conversions from any T convertible to float, one per
  // address space. The value is cast to float, then rounded to nearest-even
  // by float_to_bfloat_bits (NaNs become the canonical 0x7FC0).

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) thread
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) device
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) constant
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  /////////////////////////////////////////////////////////////////////////////
  // Conversions from bfloat

  // Implicit conversions to any T that float converts to, one per address
  // space of *this. The bits are widened exactly to float, then
  // static_cast to T. NOTE(review): these are intentionally non-explicit, so
  // bfloat16 mixes freely with builtin types in kernel code.

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const thread {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const threadgroup {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const device {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const constant {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }
};
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// Bfloat operators
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// Unary ops
|
|
// Unary negation. The operand is widened to float and negated, and the
// result is narrowed back through the implicit float -> bfloat16 constructor.
// A NaN operand therefore comes back as the canonical NaN 0x7FC0.
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
  const float negated = -static_cast<float>(x);
  return negated;
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// Binary operators
|
|
// Defines `ret_t fn(lhs_t lhs, rhs_t rhs)`. Both operands are cast to
// calc_t and combined with the infix token `op`. The calc_t result is then
// implicitly converted to ret_t on return.
#define bfloat_binop_base(op, fn, ret_t, lhs_t, rhs_t, calc_t)   \
  constexpr METAL_FUNC ret_t fn(lhs_t lhs, rhs_t rhs) {           \
    return static_cast<calc_t>(lhs) op static_cast<calc_t>(rhs);  \
  }
|
|
|
|
// Defines the symmetric mixed-type pair
//   ret_t fn(_MLX_BFloat16, other_t)  and  ret_t fn(other_t, _MLX_BFloat16),
// both computed in calc_t. This reuses bfloat_binop_base once per operand
// order.
#define bfloat_binop_helper(op, fn, ret_t, other_t, calc_t)             \
  bfloat_binop_base(op, fn, ret_t, _MLX_BFloat16, other_t, calc_t)      \
  bfloat_binop_base(op, fn, ret_t, other_t, _MLX_BFloat16, calc_t)
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// Arithmetic Operators
|
|
// Arithmetic operators; every variant computes in float.
// Result type rules:
//   bf16 op bf16, bf16 op integer -> bfloat16
//   bf16 op float, bf16 op half   -> float
#define bfloat_binop(op, fn)                                                  \
  bfloat_binop_base(                                                          \
      op, fn, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float);            \
  bfloat_binop_helper(op, fn, float, float, float);                           \
  bfloat_binop_helper(op, fn, float, half, float);                            \
  bfloat_binop_helper(op, fn, _MLX_BFloat16, int32_t, float);                 \
  bfloat_binop_helper(op, fn, _MLX_BFloat16, uint32_t, float);                \
  bfloat_binop_helper(op, fn, _MLX_BFloat16, int64_t, float);                 \
  bfloat_binop_helper(op, fn, _MLX_BFloat16, uint64_t, float);

bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// Comparison ops
|
|
// Comparison operators: every operand combination compares in float and
// returns bool.
#define bfloat_compop(op, fn)                                         \
  bfloat_binop_base(op, fn, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
  bfloat_binop_helper(op, fn, bool, float, float);                    \
  bfloat_binop_helper(op, fn, bool, half, float);                     \
  bfloat_binop_helper(op, fn, bool, int32_t, float);                  \
  bfloat_binop_helper(op, fn, bool, uint32_t, float);                 \
  bfloat_binop_helper(op, fn, bool, int64_t, float);                  \
  bfloat_binop_helper(op, fn, bool, uint64_t, float);

bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);

// The generator macros are file-private; drop them now that all operators
// have been emitted.
#undef bfloat_compop
#undef bfloat_binop_base
#undef bfloat_binop_helper
#undef bfloat_binop
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// Inplace Operators
|
|
// In-place operators mixing bfloat16 with a builtin type, in both directions
// (`bf16 op= x` and `x op= bf16`). The arithmetic is done in float, and the
// result is narrowed back into the left-hand side by plain assignment.
#define bfloat_inplace_op_helper(op, fn, other_t, space)        \
  constexpr METAL_FUNC space _MLX_BFloat16& fn(                 \
      space _MLX_BFloat16& lhs, other_t rhs) {                  \
    lhs = static_cast<float>(lhs) op static_cast<float>(rhs);   \
    return lhs;                                                 \
  }                                                             \
  constexpr METAL_FUNC space other_t& fn(                       \
      space other_t& lhs, _MLX_BFloat16 rhs) {                  \
    lhs = static_cast<float>(lhs) op static_cast<float>(rhs);   \
    return lhs;                                                 \
  }

// Emits the helper for the writable address spaces. `constant` is
// read-only, so it is deliberately not listed.
#define bfloat_inplace_op_addr_space_helper(op, fn, other_t) \
  bfloat_inplace_op_helper(op, fn, other_t, device);         \
  bfloat_inplace_op_helper(op, fn, other_t, thread);         \
  bfloat_inplace_op_helper(op, fn, other_t, threadgroup);

// Emits +=, -=, *=, /= for one builtin operand type.
#define bfloat_inplace_op(other_t)                              \
  bfloat_inplace_op_addr_space_helper(+, operator+=, other_t);  \
  bfloat_inplace_op_addr_space_helper(-, operator-=, other_t);  \
  bfloat_inplace_op_addr_space_helper(*, operator*=, other_t);  \
  bfloat_inplace_op_addr_space_helper(/, operator/=, other_t);

// Floating-point operands.
bfloat_inplace_op(float);
bfloat_inplace_op(half);

// Integer operands, grouped by width.
bfloat_inplace_op(int16_t);
bfloat_inplace_op(uint16_t);
bfloat_inplace_op(int32_t);
bfloat_inplace_op(uint32_t);
bfloat_inplace_op(int64_t);
bfloat_inplace_op(uint64_t);

#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
#undef bfloat_inplace_op
|
|
|
|
// In-place operators with bfloat16 on both sides. The arithmetic is done in
// float, and the result is rounded back into lhs via the float -> bfloat16
// conversion.
#define bfloat_inplace_op_helper(op, fn, space)                 \
  constexpr METAL_FUNC space _MLX_BFloat16& fn(                 \
      space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) {            \
    lhs = static_cast<float>(lhs) op static_cast<float>(rhs);   \
    return lhs;                                                 \
  }

// Emits the helper for the writable address spaces (`constant` excluded).
#define bfloat_inplace_op_addr_space_helper(op, fn) \
  bfloat_inplace_op_helper(op, fn, device);         \
  bfloat_inplace_op_helper(op, fn, thread);         \
  bfloat_inplace_op_helper(op, fn, threadgroup);

bfloat_inplace_op_addr_space_helper(+, operator+=);
bfloat_inplace_op_addr_space_helper(-, operator-=);
bfloat_inplace_op_addr_space_helper(*, operator*=);
bfloat_inplace_op_addr_space_helper(/, operator/=);

#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// Bfloat typedef
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Public name for the emulated type when native `bfloat` is unavailable.
using bfloat16_t = _MLX_BFloat16;
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// Bfloat numeric limits
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
|
|
#pragma METAL internals : enable
|
|
|
|
namespace metal {

// numeric_limits support for the emulated bfloat16.
// bfloat16 keeps float32's 8-bit exponent, so the exponent limits match
// float. Only the precision differs: 8 significand bits, counting the
// implicit leading one.
template <>
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
  static constexpr constant int digits = 8;
  static constexpr constant int digits10 = 2;
  static constexpr constant int max_digits10 = 4;
  static constexpr constant int radix = 2;
  static constexpr constant int min_exponent = -125;
  static constexpr constant int min_exponent10 = -37;
  static constexpr constant int max_exponent = 128;
  static constexpr constant int max_exponent10 = 38;

  // Smallest positive normal value: 2^-126.
  static constexpr bfloat16_t min() {
    return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
  }
  // Most negative finite value: -(2 - 2^-7) * 2^127.
  static constexpr bfloat16_t lowest() {
    return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
  }
  // Largest finite value: (2 - 2^-7) * 2^127 (about 3.39e38).
  static constexpr bfloat16_t max() {
    return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
  }
  // Distance from 1.0 to the next representable value: 2^-7.
  static constexpr bfloat16_t epsilon() {
    return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
  }
  // Maximum rounding error under round-to-nearest: 0.5.
  static constexpr bfloat16_t round_error() {
    return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
  }
  // Positive infinity: all-ones exponent, zero mantissa.
  static constexpr bfloat16_t infinity() {
    return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
  }
  // Canonical quiet NaN (mantissa MSB set). This is the same pattern that
  // float_to_bfloat_bits produces for any NaN input.
  static constexpr bfloat16_t quiet_NaN() {
    return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
  }
  // Signaling NaN: all-ones exponent, mantissa MSB clear, non-zero payload.
  // The previous value 0x7F80 encoded +infinity, not a NaN. 0x7FA0 matches
  // the high half of the common float32 sNaN pattern 0x7FA00000.
  static constexpr bfloat16_t signaling_NaN() {
    return _MLX_BFloat16(0x7FA0, _MLX_BFloat16::bits_to_bfloat());
  }
  // Smallest positive subnormal value: 2^-133.
  static constexpr bfloat16_t denorm_min() {
    return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
  }
};

// Returns true iff x is a NaN, i.e. the exponent is all ones and the
// mantissa is non-zero. The test uses the raw bits rather than `x != x`,
// because fast-math compilation may assume NaNs never occur and fold the
// self-comparison to false.
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
  return (x.bits_ & 0x7FFF) > 0x7F80;
}

} // namespace metal
|
|
|
|
#pragma METAL internals : disable
|
|
|
|
#endif
|
|
|
|
#include "mlx/backend/metal/kernels/bf16_math.h"
|