Choose the right MLX bf16 for extensions (#1135)

* default to custom bf

* choose right bf

* fix extensions

* fix circle conf
This commit is contained in:
Awni Hannun 2024-05-17 15:09:28 -07:00 committed by GitHub
parent b3ec792380
commit 23406c9e9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 5 deletions

View File

@ -99,7 +99,10 @@ jobs:
- run:
name: Build example extension
command: |
cd examples/extensions && python3.8 -m pip install .
source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:

View File

@ -6,9 +6,7 @@
using namespace metal;
// No support for less than metal 3.0
// anything greater has native bfloat
#ifndef METAL_3_0
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;

View File

@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#ifndef METAL_3_0
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)