mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard * add scale * review comments --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
@@ -453,8 +453,10 @@ array flatten(const array& a, StreamOrDevice s /* = {} */) {
|
||||
|
||||
array hadamard_transform(
|
||||
const array& a,
|
||||
float scale /* = 1.0 */,
|
||||
std::optional<float> scale_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
|
||||
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1));
|
||||
auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
|
||||
return array(
|
||||
a.shape(),
|
||||
|
||||
Reference in New Issue
Block a user