mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-02 22:04:45 +08:00
Use cexpf in Metal
This commit is contained in:
134
mlx/backend/metal/kernels/cexpf.h
Normal file
134
mlx/backend/metal/kernels/cexpf.h
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
// Copyright © 2008-2013 NVIDIA Corporation
|
||||
// Copyright © 2013 Filipe RNC Maia
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Forked from
|
||||
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
|
||||
|
||||
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
|
||||
// can not be used in JIT.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_math>
|
||||
|
||||
using ieee_float_shape_type = union {
|
||||
float value;
|
||||
uint32_t word;
|
||||
};
|
||||
|
||||
inline void get_float_word(thread uint32_t& i, float d) {
|
||||
ieee_float_shape_type gf_u;
|
||||
gf_u.value = (d);
|
||||
(i) = gf_u.word;
|
||||
}
|
||||
|
||||
inline void get_float_word(thread int32_t& i, float d) {
|
||||
ieee_float_shape_type gf_u;
|
||||
gf_u.value = (d);
|
||||
(i) = gf_u.word;
|
||||
}
|
||||
|
||||
inline void set_float_word(thread float& d, uint32_t i) {
|
||||
ieee_float_shape_type sf_u;
|
||||
sf_u.word = (i);
|
||||
(d) = sf_u.value;
|
||||
}
|
||||
|
||||
inline float frexp_expf(float x, thread int* expt) {
|
||||
const uint32_t k = 235;
|
||||
const float kln2 = 162.88958740F;
|
||||
|
||||
float exp_x;
|
||||
uint32_t hx;
|
||||
|
||||
exp_x = metal::exp(x - kln2);
|
||||
get_float_word(hx, exp_x);
|
||||
*expt = (hx >> 23) - (0x7f + 127) + k;
|
||||
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
|
||||
return exp_x;
|
||||
}
|
||||
|
||||
inline complex64_t ldexp_cexpf(complex64_t z, int expt) {
|
||||
float x, y, exp_x, scale1, scale2;
|
||||
int ex_expt, half_expt;
|
||||
|
||||
x = z.real;
|
||||
y = z.imag;
|
||||
exp_x = frexp_expf(x, &ex_expt);
|
||||
expt += ex_expt;
|
||||
|
||||
half_expt = expt / 2;
|
||||
set_float_word(scale1, (0x7f + half_expt) << 23);
|
||||
half_expt = expt - half_expt;
|
||||
set_float_word(scale2, (0x7f + half_expt) << 23);
|
||||
|
||||
return complex64_t{
|
||||
metal::cos(y) * exp_x * scale1 * scale2,
|
||||
metal::sin(y) * exp_x * scale1 * scale2};
|
||||
}
|
||||
|
||||
inline complex64_t cexpf(const thread complex64_t& z) {
|
||||
float x, y, exp_x;
|
||||
uint32_t hx, hy;
|
||||
|
||||
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
|
||||
|
||||
x = z.real;
|
||||
y = z.imag;
|
||||
|
||||
get_float_word(hy, y);
|
||||
hy &= 0x7fffffff;
|
||||
|
||||
/* cexp(x + I 0) = exp(x) + I 0 */
|
||||
if (hy == 0) {
|
||||
return complex64_t{metal::exp(x), y};
|
||||
}
|
||||
get_float_word(hx, x);
|
||||
/* cexp(0 + I y) = cos(y) + I sin(y) */
|
||||
if ((hx & 0x7fffffff) == 0) {
|
||||
return complex64_t{metal::cos(y), metal::sin(y)};
|
||||
}
|
||||
if (hy >= 0x7f800000) {
|
||||
if ((hx & 0x7fffffff) != 0x7f800000) {
|
||||
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
|
||||
return complex64_t{y - y, y - y};
|
||||
} else if (hx & 0x80000000) {
|
||||
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
|
||||
return complex64_t{0.0, 0.0};
|
||||
} else {
|
||||
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
|
||||
return complex64_t{x, y - y};
|
||||
}
|
||||
}
|
||||
|
||||
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
|
||||
/*
|
||||
* x is between 88.7 and 192, so we must scale to avoid
|
||||
* overflow in expf(x).
|
||||
*/
|
||||
return ldexp_cexpf(z, 0);
|
||||
} else {
|
||||
/*
|
||||
* Cases covered here:
|
||||
* - x < exp_ovfl and exp(x) won't overflow (common case)
|
||||
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
|
||||
* - x = +-Inf (generated by exp())
|
||||
* - x = NaN (spurious inexact exception from y)
|
||||
*/
|
||||
exp_x = metal::exp(x);
|
||||
return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)};
|
||||
}
|
||||
}
|
@@ -5,6 +5,7 @@
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/cexpf.h"
|
||||
#include "mlx/backend/metal/kernels/erf.h"
|
||||
#include "mlx/backend/metal/kernels/expm1f.h"
|
||||
|
||||
@@ -178,8 +179,7 @@ struct Exp {
|
||||
return metal::precise::exp(x);
|
||||
};
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto m = metal::precise::exp(x.real);
|
||||
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
||||
return cexpf(x);
|
||||
}
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user