mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Choose the right MLX bf16 for extensions (#1135)
* default to custom bf * choose right bf * fix extensions * fix circle conf
This commit is contained in:
parent
b3ec792380
commit
23406c9e9e
@ -99,7 +99,10 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Build example extension
|
name: Build example extension
|
||||||
command: |
|
command: |
|
||||||
cd examples/extensions && python3.8 -m pip install .
|
source env/bin/activate
|
||||||
|
cd examples/extensions
|
||||||
|
pip install -r requirements.txt
|
||||||
|
python setup.py build_ext -j8
|
||||||
- store_test_results:
|
- store_test_results:
|
||||||
path: test-results
|
path: test-results
|
||||||
- run:
|
- run:
|
||||||
|
@ -6,9 +6,7 @@
|
|||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
// No support for less than metal 3.0
|
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
|
||||||
// anything greater has native bfloat
|
|
||||||
#ifndef METAL_3_0
|
|
||||||
|
|
||||||
typedef bfloat bfloat16_t;
|
typedef bfloat bfloat16_t;
|
||||||
|
|
||||||
|
@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
|
|||||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef METAL_3_0
|
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
|
||||||
|
|
||||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user