mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
* Contiguous scan * Strided scan * Enable tests * Fix failing logaddexp test * Use cexpf in Metal
135 lines
3.6 KiB
C
135 lines
3.6 KiB
C
// Copyright © 2025 Apple Inc.
|
|
// Copyright © 2008-2013 NVIDIA Corporation
|
|
// Copyright © 2013 Filipe RNC Maia
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// Forked from
|
|
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
|
|
|
|
// TODO: We should use thrust::exp, but the thrust header in old CUDA versions
// cannot be used in JIT.
|
|
|
|
#pragma once
|
|
|
|
#include <metal_math>
|
|
|
|
// Overlay of a 32-bit float and its raw bit pattern. The get_float_word /
// set_float_word helpers use it to switch between the two views.
using ieee_float_shape_type = union {
  float value; // floating-point view of the storage
  uint32_t word; // raw 32-bit view of the same storage
};
|
|
|
|
// Writes the raw 32-bit pattern of the float `d` into `i`.
//
// i: output, receives the bits of `d` unchanged.
// d: the float whose representation is read.
inline void get_float_word(thread uint32_t& i, float d) {
  ieee_float_shape_type bits;
  bits.value = d;
  i = bits.word;
}
|
|
|
|
// Signed overload: writes the raw 32-bit pattern of the float `d` into `i`,
// converted to int32_t.
//
// i: output, receives the bits of `d`.
// d: the float whose representation is read.
inline void get_float_word(thread int32_t& i, float d) {
  ieee_float_shape_type bits;
  bits.value = d;
  // `word` is unsigned; make the conversion to the signed output explicit.
  i = static_cast<int32_t>(bits.word);
}
|
|
|
|
// Writes into `d` the float whose raw 32-bit pattern is `i`.
//
// d: output, receives the float built from `i`.
// i: the 32-bit pattern to reinterpret.
inline void set_float_word(thread float& d, uint32_t i) {
  ieee_float_shape_type bits;
  bits.word = i;
  d = bits.value;
}
|
|
|
|
// Evaluates exp(x) as a (mantissa-scaled float, power-of-two exponent) pair.
//
// The exponential is taken of the reduced argument x - kln2, where kln2 is
// k * ln(2) for k = 235. The exponent field of that result is then replaced
// by 0x7f + 127, and the power of two that was factored out (including the
// k from the reduction) is written to *expt.
//
// x:    the argument.
// expt: output, receives the power-of-two exponent that was factored out.
// Returns the value of exp(x - kln2) with its exponent field forced to
// 0x7f + 127 and its sign bit cleared.
inline float frexp_expf(float x, thread int* expt) {
  const uint32_t k = 235; // reduction constant
  const float kln2 = 162.88958740F; // k * ln(2)
  const uint32_t forced_exponent = 0x7f + 127; // exponent field of the result

  float exp_x = metal::exp(x - kln2);

  uint32_t hx;
  get_float_word(hx, exp_x);

  // Exponent field of exp(x - kln2), rebased against the forced exponent and
  // corrected by k for the argument reduction.
  *expt = static_cast<int>((hx >> 23) - forced_exponent + k);

  // Keep the 23 mantissa bits and install the forced exponent.
  set_float_word(exp_x, (hx & 0x7fffff) | (forced_exponent << 23));
  return exp_x;
}
|
|
|
|
// Computes exp(z) * 2^expt for complex z without forming exp(z.real) as a
// single float.
//
// exp(z.real) is obtained from frexp_expf as a scaled float plus a
// power-of-two exponent. The total exponent is split into two halves, each
// turned into a power-of-two float, and the two scales are applied one after
// the other to the cos/sin products.
//
// z:    the complex argument (real part in z.real, imaginary part in z.imag).
// expt: additional power-of-two exponent to apply to the result.
// Returns the scaled complex exponential.
inline complex64_t ldexp_cexpf(complex64_t z, int expt) {
  const float x = z.real;
  const float y = z.imag;

  int ex_expt;
  const float exp_x = frexp_expf(x, &ex_expt);
  expt += ex_expt;

  // Split the exponent in two so each half can be encoded as its own
  // power-of-two float (biased exponent placed in bits 23 and up).
  const int first_half = expt / 2;
  const int second_half = expt - first_half;

  // Shift in unsigned arithmetic: left-shifting a negative signed value is
  // undefined, while the unsigned shift is always well defined.
  float scale1, scale2;
  set_float_word(scale1, static_cast<uint32_t>(0x7f + first_half) << 23);
  set_float_word(scale2, static_cast<uint32_t>(0x7f + second_half) << 23);

  // The left-to-right multiplication order is deliberate: the scales are
  // applied one at a time to the already-attenuated cos/sin product.
  return complex64_t{
      metal::cos(y) * exp_x * scale1 * scale2,
      metal::sin(y) * exp_x * scale1 * scale2};
}
|
|
|
|
// Complex exponential of z = x + I y for single-precision components.
//
// Special inputs are classified from the raw bit patterns of x and y and
// handled first (zero imaginary part, zero real part, Inf/NaN imaginary
// part). For the remaining inputs, a real part whose bit pattern lies in
// [exp_ovfl, cexp_ovfl] is routed through ldexp_cexpf so that exp(x) is not
// formed directly; everything else uses exp(x) * (cos(y) + I sin(y)).
//
// z: the complex argument (real part in z.real, imaginary part in z.imag).
// Returns exp(z).
inline complex64_t cexpf(const thread complex64_t& z) {
  // Bit patterns bounding the range of x that needs scaling (per the comment
  // on that branch below, roughly 88.7 to 192).
  const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;

  const float x = z.real;
  const float y = z.imag;

  uint32_t hy;
  get_float_word(hy, y);
  hy &= 0x7fffffff; // drop the sign bit of y

  /* cexp(x + I 0) = exp(x) + I 0 */
  if (hy == 0) {
    return complex64_t{metal::exp(x), y};
  }

  uint32_t hx;
  get_float_word(hx, x);

  /* cexp(0 + I y) = cos(y) + I sin(y) */
  if ((hx & 0x7fffffff) == 0) {
    return complex64_t{metal::cos(y), metal::sin(y)};
  }

  // y is Inf or NaN.
  if (hy >= 0x7f800000) {
    if ((hx & 0x7fffffff) != 0x7f800000) {
      /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
      return complex64_t{y - y, y - y};
    } else if (hx & 0x80000000) {
      /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
      return complex64_t{0.0f, 0.0f};
    } else {
      /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
      return complex64_t{x, y - y};
    }
  }

  // hx still carries the sign bit here, so any negative x compares above
  // cexp_ovfl and falls through to the plain path.
  if (hx >= exp_ovfl && hx <= cexp_ovfl) {
    /*
     * x is between 88.7 and 192, so we must scale to avoid
     * overflow in expf(x).
     */
    return ldexp_cexpf(z, 0);
  } else {
    /*
     * Cases covered here:
     *  -  x < exp_ovfl and exp(x) won't overflow (common case)
     *  -  x > cexp_ovfl, so exp(x) * s overflows for all s > 0
     *  -  x = +-Inf (generated by exp())
     *  -  x = NaN (spurious inexact exception from y)
     */
    const float exp_x = metal::exp(x);
    return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)};
  }
}
|