diff --git a/.circleci/config.yml b/.circleci/config.yml index e9eebff69..a2455aa19 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -99,7 +99,10 @@ jobs: - run: name: Build example extension command: | - cd examples/extensions && python3.8 -m pip install . + source env/bin/activate + cd examples/extensions + pip install -r requirements.txt + python setup.py build_ext -j8 - store_test_results: path: test-results - run: diff --git a/mlx/backend/metal/kernels/bf16.h b/mlx/backend/metal/kernels/bf16.h index 03c73f9c2..726b676bb 100644 --- a/mlx/backend/metal/kernels/bf16.h +++ b/mlx/backend/metal/kernels/bf16.h @@ -6,9 +6,7 @@ using namespace metal; -// No support for less than metal 3.0 -// anything greater has native bfloat -#ifndef METAL_3_0 +#if defined METAL_3_1 || (__METAL_VERSION__ >= 310) typedef bfloat bfloat16_t; diff --git a/mlx/backend/metal/kernels/bf16_math.h b/mlx/backend/metal/kernels/bf16_math.h index 929429bdd..8c48b8cfd 100644 --- a/mlx/backend/metal/kernels/bf16_math.h +++ b/mlx/backend/metal/kernels/bf16_math.h @@ -369,7 +369,7 @@ instantiate_metal_math_funcs( return static_cast(__metal_simd_xor(static_cast(data))); \ } -#ifndef METAL_3_0 +#if defined METAL_3_1 || (__METAL_VERSION__ >= 310) #define bfloat16_to_uint16(x) as_type(x) #define uint16_to_bfloat16(x) as_type(x)