mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 06:24:35 +08:00

* Update veclib and BNNS includes to `#include <Accelerate/Accelerate.h>` for compatibility with iOS
* Mark float literals in softmax.cpp as float16_t to fix compile errors on iOS
* Add ARM NEON vector operation guards
* Redirect to common backend for consistency
28 lines
762 B
C++
28 lines
762 B
C++
// Copyright © 2023-2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <stdexcept>

#include <Accelerate/Accelerate.h>

#include "mlx/dtype.h"
|
|
|
|
namespace mlx::core {
|
|
|
|
/// Map an MLX dtype to the corresponding BNNS data type.
///
/// For unsigned, signed, and floating-point dtypes the result is built by
/// OR-ing a BNNS kind flag (`BNNSDataTypeUIntBit`, `BNNSDataTypeIntBit`,
/// `BNNSDataTypeFloatBit`) with the element width in bits. This relies on
/// BNNS enumerators being laid out as "kind flag | width". The function does
/// not check whether BNNS actually defines a type for that width.
/// NOTE(review): confirm widths such as 64-bit against the BNNS header.
///
/// Declared `inline` because the definition lives in a header. A plain
/// definition would violate the one-definition rule (duplicate symbols at
/// link time) as soon as two translation units include this file.
///
/// @param mlx_dtype  The MLX dtype to convert.
/// @return The matching `BNNSDataType`.
/// @throws std::invalid_argument for complex dtypes (unsupported by BNNS)
///         and for any dtype kind not handled by the switch below.
inline BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
  // size_of() presumably reports bytes; the BNNS encoding uses bits.
  const uint32_t size_bits = size_of(mlx_dtype) * 8;

  switch (kindof(mlx_dtype)) {
    case Dtype::Kind::b:
      return BNNSDataTypeBoolean;
    case Dtype::Kind::u:
      return static_cast<BNNSDataType>(BNNSDataTypeUIntBit | size_bits);
    case Dtype::Kind::i:
      return static_cast<BNNSDataType>(BNNSDataTypeIntBit | size_bits);
    case Dtype::Kind::f:
      return static_cast<BNNSDataType>(BNNSDataTypeFloatBit | size_bits);
    case Dtype::Kind::V:
      // Kind V is mapped straight to BNNS bfloat16 (presumably MLX's
      // bfloat16 kind; it is not composed from a width like the cases above).
      return BNNSDataTypeBFloat16;
    case Dtype::Kind::c:
      throw std::invalid_argument("BNNS does not support complex types");
  }

  // Without this, control would flow off the end of a non-void function
  // (undefined behavior) if kindof() ever yields a kind not listed above,
  // e.g. after a new Dtype::Kind is added.
  throw std::invalid_argument("[to_bnns_dtype] Unsupported dtype kind.");
}
|
|
|
|
} // namespace mlx::core
|