mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Choose the right MLX bf16 for extensions (#1135)
* default to custom bf * choose right bf * fix extensions * fix circle conf
This commit is contained in:
parent
b3ec792380
commit
23406c9e9e
@ -99,7 +99,10 @@ jobs:
|
||||
- run:
|
||||
name: Build example extension
|
||||
command: |
|
||||
cd examples/extensions && python3.8 -m pip install .
|
||||
source env/bin/activate
|
||||
cd examples/extensions
|
||||
pip install -r requirements.txt
|
||||
python setup.py build_ext -j8
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- run:
|
||||
|
@ -6,9 +6,7 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// No support for less than metal 3.0
|
||||
// anything greater has native bfloat
|
||||
#ifndef METAL_3_0
|
||||
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
|
||||
|
||||
typedef bfloat bfloat16_t;
|
||||
|
||||
|
@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
|
||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||
}
|
||||
|
||||
#ifndef METAL_3_0
|
||||
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
|
||||
|
||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||
|
Loading…
Reference in New Issue
Block a user