mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-11 23:14:50 +08:00
only build for macos 14 and up (#2731)
* only build for macos 14 and up * bump metal cpp
This commit is contained in:
7
.github/actions/build-macos/action.yml
vendored
7
.github/actions/build-macos/action.yml
vendored
@@ -107,11 +107,6 @@ runs:
|
|||||||
-v python/tests \
|
-v python/tests \
|
||||||
-o test-results/gpu_jit
|
-o test-results/gpu_jit
|
||||||
|
|
||||||
- name: Build macOS 13 package
|
|
||||||
if: inputs.build-type == 'release'
|
|
||||||
uses: ./.github/actions/build-macos-release
|
|
||||||
with:
|
|
||||||
macos-target: 13.0
|
|
||||||
- name: Build macOS 14 package
|
- name: Build macOS 14 package
|
||||||
if: inputs.build-type == 'release'
|
if: inputs.build-type == 'release'
|
||||||
uses: ./.github/actions/build-macos-release
|
uses: ./.github/actions/build-macos-release
|
||||||
@@ -121,4 +116,4 @@ runs:
|
|||||||
if: inputs.build-type == 'release'
|
if: inputs.build-type == 'release'
|
||||||
uses: ./.github/actions/build-macos-release
|
uses: ./.github/actions/build-macos-release
|
||||||
with:
|
with:
|
||||||
macos-target: 15.0
|
macos-target: 15.0
|
||||||
|
|||||||
3
.github/workflows/nightly.yml
vendored
3
.github/workflows/nightly.yml
vendored
@@ -53,8 +53,7 @@ jobs:
|
|||||||
if: github.repository == 'ml-explore/mlx'
|
if: github.repository == 'ml-explore/mlx'
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.10", "3.13"]
|
python-version: ["3.10", "3.14"]
|
||||||
# TODO: 3.14 had issues finding a compatible tensorflow
|
|
||||||
env:
|
env:
|
||||||
MACOSX_DEPLOYMENT_TARGET: "15.0"
|
MACOSX_DEPLOYMENT_TARGET: "15.0"
|
||||||
runs-on: [self-hosted, macos]
|
runs-on: [self-hosted, macos]
|
||||||
|
|||||||
3
.github/workflows/release.yml
vendored
3
.github/workflows/release.yml
vendored
@@ -73,8 +73,7 @@ jobs:
|
|||||||
if: github.repository == 'ml-explore/mlx'
|
if: github.repository == 'ml-explore/mlx'
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||||
# TODO: 3.14 had issues finding a compatible tensorflow
|
|
||||||
runs-on: [self-hosted, macos]
|
runs-on: [self-hosted, macos]
|
||||||
env:
|
env:
|
||||||
PYPI_RELEASE: 1
|
PYPI_RELEASE: 1
|
||||||
|
|||||||
@@ -127,9 +127,12 @@ if(MLX_BUILD_METAL)
|
|||||||
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
|
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
|
||||||
|
|
||||||
set(METAL_CPP_URL
|
set(METAL_CPP_URL
|
||||||
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
|
https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
|
||||||
|
|
||||||
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||||
|
if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
|
||||||
|
message(FATAL_ERROR "MLX requires macOS >= 14.0")
|
||||||
|
endif()
|
||||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||||
endif()
|
endif()
|
||||||
execute_process(
|
execute_process(
|
||||||
@@ -138,7 +141,6 @@ if(MLX_BUILD_METAL)
|
|||||||
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||||
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||||
|
|
||||||
FetchContent_MakeAvailable(metal_cpp)
|
FetchContent_MakeAvailable(metal_cpp)
|
||||||
target_include_directories(
|
target_include_directories(
|
||||||
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
||||||
|
|||||||
@@ -17,11 +17,10 @@ To install from PyPI your system must meet the following requirements:
|
|||||||
|
|
||||||
- Using an M series chip (Apple silicon)
|
- Using an M series chip (Apple silicon)
|
||||||
- Using a native Python >= 3.10
|
- Using a native Python >= 3.10
|
||||||
- macOS >= 13.5
|
- macOS >= 14.0
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
MLX is only available on devices running macOS >= 13.5
|
MLX is only available on devices running macOS 14.0 or higher.
|
||||||
It is highly recommended to use macOS 14 (Sonoma)
|
|
||||||
|
|
||||||
CUDA
|
CUDA
|
||||||
^^^^
|
^^^^
|
||||||
|
|||||||
@@ -21,14 +21,8 @@ function(make_jit_source SRC_FILE)
|
|||||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
|
||||||
endfunction(make_jit_source)
|
endfunction(make_jit_source)
|
||||||
|
|
||||||
make_jit_source(
|
make_jit_source(utils kernels/bf16.h kernels/bf16_math.h kernels/complex.h
|
||||||
utils
|
kernels/defines.h)
|
||||||
kernels/jit/bf16.h
|
|
||||||
kernels/metal_3_0/bf16.h
|
|
||||||
kernels/metal_3_1/bf16.h
|
|
||||||
kernels/bf16_math.h
|
|
||||||
kernels/complex.h
|
|
||||||
kernels/defines.h)
|
|
||||||
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h)
|
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h)
|
||||||
make_jit_source(binary_ops)
|
make_jit_source(binary_ops)
|
||||||
make_jit_source(ternary_ops)
|
make_jit_source(ternary_ops)
|
||||||
|
|||||||
@@ -21,12 +21,12 @@ constexpr const char* default_mtllib_path = METAL_PATH;
|
|||||||
|
|
||||||
auto get_metal_version() {
|
auto get_metal_version() {
|
||||||
auto get_metal_version_ = []() {
|
auto get_metal_version_ = []() {
|
||||||
if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
|
if (__builtin_available(macOS 26, iOS 26, tvOS 26, visionOS 26, *)) {
|
||||||
|
return MTL::LanguageVersion4_0;
|
||||||
|
} else if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
|
||||||
return MTL::LanguageVersion3_2;
|
return MTL::LanguageVersion3_2;
|
||||||
} else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
|
|
||||||
return MTL::LanguageVersion3_1;
|
|
||||||
} else {
|
} else {
|
||||||
return MTL::LanguageVersion3_0;
|
return MTL::LanguageVersion3_1;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
static auto metal_version_ = get_metal_version_();
|
static auto metal_version_ = get_metal_version_();
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
set(BASE_HEADERS
|
set(BASE_HEADERS
|
||||||
metal_3_1/bf16.h
|
bf16.h
|
||||||
metal_3_0/bf16.h
|
|
||||||
bf16_math.h
|
bf16_math.h
|
||||||
complex.h
|
complex.h
|
||||||
defines.h
|
defines.h
|
||||||
@@ -18,16 +17,9 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
|||||||
set(METAL_FLAGS ${METAL_FLAGS}
|
set(METAL_FLAGS ${METAL_FLAGS}
|
||||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||||
endif()
|
endif()
|
||||||
if(MLX_METAL_VERSION GREATER_EQUAL 310)
|
|
||||||
set(VERSION_INCLUDES
|
|
||||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
|
|
||||||
else()
|
|
||||||
set(VERSION_INCLUDES
|
|
||||||
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
|
|
||||||
endif()
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||||
-I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
|
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||||
OUTPUT ${TARGET}.air
|
OUTPUT ${TARGET}.air
|
||||||
COMMENT "Building ${TARGET}.air"
|
COMMENT "Building ${TARGET}.air"
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define jit_if #if
|
|
||||||
#define jit_else #else
|
|
||||||
#define jit_endif #endif
|
|
||||||
|
|
||||||
jit_if (__METAL_VERSION__ >= 310)
|
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
|
|
||||||
|
|
||||||
jit_else
|
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
|
|
||||||
|
|
||||||
jit_endif // clang-format on
|
|
||||||
@@ -1,314 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <metal_stdlib>
|
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Helpers
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
|
|
||||||
// Check for nan
|
|
||||||
if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
|
|
||||||
_fp_encoding_traits<float>::inf_mask) {
|
|
||||||
return uint16_t(as_type<uint32_t>(0x7FC0));
|
|
||||||
}
|
|
||||||
// Take bits
|
|
||||||
uint32_t float_bits = as_type<uint32_t>(x);
|
|
||||||
|
|
||||||
// Round to nearest even
|
|
||||||
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
|
|
||||||
|
|
||||||
// Take upper 16 bits
|
|
||||||
return float_bits >> 16;
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
|
|
||||||
// Upper 16 bits are the data and lower 16 bits are 0s
|
|
||||||
return as_type<float>((uint32_t)x << 16);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct _MLX_BFloat16;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static constexpr constant bool can_convert_to_bfloat =
|
|
||||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static constexpr constant bool can_convert_from_bfloat =
|
|
||||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Bfloat struct
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
struct _MLX_BFloat16 {
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Constructors
|
|
||||||
uint16_t bits_;
|
|
||||||
_MLX_BFloat16() thread = default;
|
|
||||||
_MLX_BFloat16() threadgroup = default;
|
|
||||||
_MLX_BFloat16() device = default;
|
|
||||||
_MLX_BFloat16() constant = default;
|
|
||||||
|
|
||||||
struct bits_to_bfloat_struct {};
|
|
||||||
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
|
|
||||||
return bits_to_bfloat_struct();
|
|
||||||
}
|
|
||||||
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
|
|
||||||
: bits_(bits) {}
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Conversions to bfloat
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
|
|
||||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
|
|
||||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC _MLX_BFloat16(T x) device
|
|
||||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
|
|
||||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Conversions from bfloat
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC operator T() const thread {
|
|
||||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC operator T() const threadgroup {
|
|
||||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC operator T() const device {
|
|
||||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
|
||||||
constexpr METAL_FUNC operator T() const constant {
|
|
||||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Bfloat operators
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Unary ops
|
|
||||||
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
|
|
||||||
return -static_cast<float>(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Binary operators
|
|
||||||
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
|
||||||
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
|
|
||||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
|
||||||
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
|
|
||||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
|
||||||
} \
|
|
||||||
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
|
|
||||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
|
||||||
}
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Arithmetic Operators
|
|
||||||
#define bfloat_binop(_op_, _operator_) \
|
|
||||||
bfloat_binop_base( \
|
|
||||||
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, float, float, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, float, half, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
|
|
||||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
|
|
||||||
|
|
||||||
bfloat_binop(+, operator+);
|
|
||||||
bfloat_binop(-, operator-);
|
|
||||||
bfloat_binop(*, operator*);
|
|
||||||
bfloat_binop(/, operator/);
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Comparison ops
|
|
||||||
#define bfloat_compop(__op__, __operator__) \
|
|
||||||
bfloat_binop_base( \
|
|
||||||
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
|
||||||
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
|
||||||
|
|
||||||
bfloat_compop(>, operator>);
|
|
||||||
bfloat_compop(<, operator<);
|
|
||||||
bfloat_compop(>=, operator>=);
|
|
||||||
bfloat_compop(<=, operator<=);
|
|
||||||
bfloat_compop(==, operator==);
|
|
||||||
bfloat_compop(!=, operator!=);
|
|
||||||
|
|
||||||
#undef bfloat_compop
|
|
||||||
#undef bfloat_binop_base
|
|
||||||
#undef bfloat_binop_helper
|
|
||||||
#undef bfloat_binop
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Inplace Operators
|
|
||||||
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
|
|
||||||
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
|
||||||
addr_space _MLX_BFloat16& lhs, itype rhs) { \
|
|
||||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
|
||||||
return lhs; \
|
|
||||||
} \
|
|
||||||
constexpr METAL_FUNC addr_space itype& __operator__( \
|
|
||||||
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
|
|
||||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
|
||||||
return lhs; \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
|
|
||||||
|
|
||||||
#define bfloat_inplace_op(itype) \
|
|
||||||
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
|
|
||||||
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
|
|
||||||
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
|
|
||||||
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
|
|
||||||
|
|
||||||
bfloat_inplace_op(float);
|
|
||||||
bfloat_inplace_op(half);
|
|
||||||
bfloat_inplace_op(int16_t);
|
|
||||||
bfloat_inplace_op(int32_t);
|
|
||||||
bfloat_inplace_op(int64_t);
|
|
||||||
bfloat_inplace_op(uint16_t);
|
|
||||||
bfloat_inplace_op(uint32_t);
|
|
||||||
bfloat_inplace_op(uint64_t);
|
|
||||||
|
|
||||||
#undef bfloat_inplace_op_helper
|
|
||||||
#undef bfloat_inplace_op_addr_space_helper
|
|
||||||
#undef bfloat_inplace_op
|
|
||||||
|
|
||||||
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
|
|
||||||
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
|
||||||
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
|
|
||||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
|
||||||
return lhs; \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, device); \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, thread); \
|
|
||||||
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
|
|
||||||
|
|
||||||
bfloat_inplace_op_addr_space_helper(+, operator+=);
|
|
||||||
bfloat_inplace_op_addr_space_helper(-, operator-=);
|
|
||||||
bfloat_inplace_op_addr_space_helper(*, operator*=);
|
|
||||||
bfloat_inplace_op_addr_space_helper(/, operator/=);
|
|
||||||
|
|
||||||
#undef bfloat_inplace_op_helper
|
|
||||||
#undef bfloat_inplace_op_addr_space_helper
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Bfloat typedef
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
typedef struct _MLX_BFloat16 bfloat16_t;
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Bfloat numeric limits
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
#pragma METAL internals : enable
|
|
||||||
|
|
||||||
namespace metal {
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
|
|
||||||
static constexpr constant int digits = 8;
|
|
||||||
static constexpr constant int digits10 = 2;
|
|
||||||
static constexpr constant int max_digits10 = 4;
|
|
||||||
static constexpr constant int radix = 2;
|
|
||||||
static constexpr constant int min_exponent = -125;
|
|
||||||
static constexpr constant int min_exponent10 = -37;
|
|
||||||
static constexpr constant int max_exponent = 128;
|
|
||||||
static constexpr constant int max_exponent10 = 38;
|
|
||||||
|
|
||||||
static constexpr bfloat16_t min() {
|
|
||||||
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t lowest() {
|
|
||||||
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t max() {
|
|
||||||
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t epsilon() {
|
|
||||||
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t round_error() {
|
|
||||||
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t infinity() {
|
|
||||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t quiet_NaN() {
|
|
||||||
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t signaling_NaN() {
|
|
||||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
static constexpr bfloat16_t denorm_min() {
|
|
||||||
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
|
||||||
return x != x;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace metal
|
|
||||||
|
|
||||||
#pragma METAL internals : disable
|
|
||||||
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
|
|
||||||
return x.bits_;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
|
|
||||||
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
|
|
||||||
}
|
|
||||||
@@ -4,11 +4,7 @@
|
|||||||
|
|
||||||
#include <metal_math>
|
#include <metal_math>
|
||||||
|
|
||||||
// The correct bf16.h is included based on the metal version
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
// by giving the correct path to -I during compilation
|
|
||||||
// e.g. mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0
|
|
||||||
#include "bf16.h"
|
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16_math.h"
|
#include "mlx/backend/metal/kernels/bf16_math.h"
|
||||||
#include "mlx/backend/metal/kernels/complex.h"
|
#include "mlx/backend/metal/kernels/complex.h"
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ bool is_available() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void start_capture(std::string path, id object) {
|
void start_capture(std::string path, NS::Object* object) {
|
||||||
auto pool = new_scoped_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
|
|
||||||
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
|
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
import platform
|
|
||||||
|
|
||||||
if platform.system() == "Darwin":
|
|
||||||
version = tuple(map(int, platform.mac_ver()[0].split(".")))
|
|
||||||
major, minor = version[0], version[1]
|
|
||||||
if (major, minor) < (13, 5):
|
|
||||||
raise ImportError(
|
|
||||||
f"Only macOS 13.5 and newer are supported, not {major}.{minor}"
|
|
||||||
)
|
|
||||||
@@ -28,7 +28,6 @@ NB_MODULE(core, m) {
|
|||||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||||
|
|
||||||
auto reprlib_fix = nb::module_::import_("mlx._reprlib_fix");
|
auto reprlib_fix = nb::module_::import_("mlx._reprlib_fix");
|
||||||
nb::module_::import_("mlx._os_warning");
|
|
||||||
nb::set_leak_warnings(false);
|
nb::set_leak_warnings(false);
|
||||||
|
|
||||||
init_mlx_func(m);
|
init_mlx_func(m);
|
||||||
|
|||||||
Reference in New Issue
Block a user