Choose the right MLX bf16 for extensions (#1135)

* default to custom bf

* choose right bf

* fix extensions

* fix circle conf
This commit is contained in:
Awni Hannun 2024-05-17 15:09:28 -07:00 committed by GitHub
parent b3ec792380
commit 23406c9e9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 5 deletions

View File

@ -99,7 +99,10 @@ jobs:
- run: - run:
name: Build example extension name: Build example extension
command: | command: |
cd examples/extensions && python3.8 -m pip install . source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results: - store_test_results:
path: test-results path: test-results
- run: - run:

View File

@ -6,9 +6,7 @@
using namespace metal; using namespace metal;
// No support for less than metal 3.0 #if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
// anything greater has native bfloat
#ifndef METAL_3_0
typedef bfloat bfloat16_t; typedef bfloat bfloat16_t;

View File

@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \ return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
} }
#ifndef METAL_3_0 #if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x) #define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x) #define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)