Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 09:21:16 +08:00

Fix complex power and print (#2286)

* fix complex power and print
* fix complex matmul shape

Parent: fddb6933e1
Commit: 8402a2acf4
@@ -194,6 +194,13 @@ struct Power {
     }
     return res;
   } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+    if (base.y == 0 && base.x == 0) {
+      if (isnan(exp.x) || isnan(exp.y)) {
+        auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
+        return make_cuFloatComplex(nan, nan);
+      }
+      return make_cuFloatComplex(0.0, 0.0);
+    }
     auto x_theta = atan2f(base.y, base.x);
     auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
     auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
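Note: the kernels evaluate z^w through the polar form of the base. Writing z = r e^{i\theta} with \theta = \operatorname{atan2}(\operatorname{Im} z, \operatorname{Re} z) and w = c + di, and matching the diff's names (x_ln_r = \ln r, x_theta = \theta):

    z^w = e^{w \log z} = e^{(c + di)(\ln r + i\theta)}
        = e^{c \ln r - d\theta} \cdot e^{i(c\theta + d\ln r)}

The first factor is exactly mag = expf(exp.x * x_ln_r - exp.y * x_theta). The new guard is needed because \ln r diverges as r \to 0: with it, 0^w returns 0 for any non-NaN exponent, and NaN in both components otherwise, instead of whatever atan2f(0, 0) and logf(0) would propagate.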
@@ -235,6 +235,13 @@ struct Power {

   template <>
   complex64_t operator()(complex64_t x, complex64_t y) {
+    if (x.real == 0 && x.imag == 0) {
+      if (metal::isnan(y.real) || metal::isnan(y.imag)) {
+        auto nan = metal::numeric_limits<float>::quiet_NaN();
+        return {nan, nan};
+      }
+      return {0.0, 0.0};
+    }
     auto x_theta = metal::atan2(x.imag, x.real);
     auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
     auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
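Note: the CUDA and Metal branches implement the same formula, so it can be sanity-checked on the host. A minimal sketch, assuming std::complex<float> in place of cuComplex / complex64_t; the phase term after mag is inferred from the derivation above, since both hunks end at the mag line:

    #include <cmath>
    #include <complex>
    #include <cstdio>
    #include <limits>

    // Host-side mirror of the kernel's polar-form power (illustrative only).
    std::complex<float> complex_pow(std::complex<float> base,
                                    std::complex<float> exp) {
      // Guard the zero base: log/atan2 are undefined at the origin.
      if (base.real() == 0 && base.imag() == 0) {
        if (std::isnan(exp.real()) || std::isnan(exp.imag())) {
          auto nan = std::numeric_limits<float>::quiet_NaN();
          return {nan, nan};
        }
        return {0.0f, 0.0f};
      }
      auto x_theta = std::atan2(base.imag(), base.real());
      auto x_ln_r = 0.5f * std::log(base.real() * base.real() +
                                    base.imag() * base.imag());
      auto mag = std::exp(exp.real() * x_ln_r - exp.imag() * x_theta);
      // Phase of the result (assumed; not shown in the truncated hunks).
      auto phase = exp.real() * x_theta + exp.imag() * x_ln_r;
      return {mag * std::cos(phase), mag * std::sin(phase)};
    }

    int main() {
      auto out = complex_pow({0.0f, 1.0f}, {2.0f, 0.0f});
      std::printf("%g%+gj\n", out.real(), out.imag());  // i^2 ~ -1 + 0j
    }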
mlx/ops.cpp (51 changed lines)
@@ -2847,21 +2847,6 @@ array matmul(
         "[matmul] Got 0 dimension input. Inputs must "
         "have at least one dimension.");
   }
-  if (a.ndim() == 1) {
-    // Insert a singleton dim in the beginning
-    a = expand_dims(a, 0, s);
-  }
-  if (b.ndim() == 1) {
-    // Insert a singleton dim at the end
-    b = expand_dims(b, 1, s);
-  }
-  if (a.shape(-1) != b.shape(-2)) {
-    std::ostringstream msg;
-    msg << "[matmul] Last dimension of first input with shape " << a.shape()
-        << " must match second to last dimension of"
-        << " second input with shape " << b.shape() << ".";
-    throw std::invalid_argument(msg.str());
-  }

   // complex matmul using Karatsuba's Algorithm
   if (a.dtype() == complex64 || b.dtype() == complex64) {
@@ -2883,6 +2868,22 @@ array matmul(
         c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
   }

+  if (a.ndim() == 1) {
+    // Insert a singleton dim in the beginning
+    a = expand_dims(a, 0, s);
+  }
+  if (b.ndim() == 1) {
+    // Insert a singleton dim at the end
+    b = expand_dims(b, 1, s);
+  }
+  if (a.shape(-1) != b.shape(-2)) {
+    std::ostringstream msg;
+    msg << "[matmul] Last dimension of first input with shape " << a.shape()
+        << " must match second to last dimension of"
+        << " second input with shape " << b.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
   // Type promotion
   auto out_type = promote_types(a.dtype(), b.dtype());
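Note: the "Karatsuba" branch presumably uses the standard three-multiplication identity (the c_real/c_imag computation itself is outside this hunk), trading four real matmuls for three:

    (A + iB)(C + iD): \quad t_1 = AC, \quad t_2 = BD,
    \text{real} = t_1 - t_2, \qquad \text{imag} = (A + B)(C + D) - t_1 - t_2

Because this branch returns early via recursive matmul calls on the real and imaginary parts, the 1-D fix-ups now run after it: the recursive calls expand and squeeze the singleton dimensions themselves, which appears to be the shape bug named in the commit message (a pre-expanded 1-D input previously kept its singleton dimension through the complex path).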
@@ -4240,6 +4241,16 @@ array addmm(
         "have at least one dimension.");
   }

+  // Type promotion
+  auto out_type = result_type(a, b, c);
+
+  if (out_type == complex64) {
+    return add(
+        multiply(matmul(a, b, s), array(alpha), s),
+        multiply(array(beta), c, s),
+        s);
+  }
+
   if (a.ndim() == 1) {
     // Insert a singleton dim in the beginning
     a = expand_dims(a, 0, s);
@@ -4257,16 +4268,6 @@ array addmm(
     throw std::invalid_argument(msg.str());
   }

-  // Type promotion
-  auto out_type = result_type(a, b, c);
-
-  if (out_type == complex64) {
-    return add(
-        multiply(matmul(a, b, s), array(alpha), s),
-        multiply(array(beta), c, s),
-        s);
-  }
-
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
     msg << "[addmm] Only real floating point types are supported but "
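Note: for complex64 inputs, addmm now short-circuits at the op level before the real-only dtype check:

    \text{addmm}(c, a, b) = \alpha\,(a\,b) + \beta\,c

composed from matmul, multiply, and add, so the complex matmul path (including the new 1-D shape handling, which the TestBlas addition below checks) is reused instead of falling through to the real-only path.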
@@ -69,7 +69,12 @@ inline void PrintFormatter::print(std::ostream& os, double val) {
   os << val;
 }
 inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
-  os << val;
+  os << val.real();
+  if (val.imag() >= 0 || std::isnan(val.imag())) {
+    os << "+" << val.imag() << "j";
+  } else {
+    os << "-" << -val.imag() << "j";
+  }
 }

 PrintFormatter& get_global_formatter() {
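Note: a standalone illustration of the new formatting rule, using std::complex<float> as a stand-in for mlx's complex64_t (an assumption for illustration):

    #include <cmath>
    #include <complex>
    #include <iostream>

    // Mirrors the new PrintFormatter branch: a negative imaginary part is
    // negated so the output reads "1-2j" rather than "1+-2j"; the explicit
    // isnan check routes NaN to the "+" branch, since comparisons with NaN
    // are always false.
    void print_complex(std::ostream& os, std::complex<float> val) {
      os << val.real();
      if (val.imag() >= 0 || std::isnan(val.imag())) {
        os << "+" << val.imag() << "j";
      } else {
        os << "-" << -val.imag() << "j";
      }
    }

    int main() {
      print_complex(std::cout, {1.0f, -2.0f});          // prints 1-2j
      std::cout << '\n';
      print_complex(std::cout, {0.0f, std::nanf("")});  // prints 0+nanj
      std::cout << '\n';
    }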
@@ -1195,6 +1195,16 @@ class TestBlas(mlx_tests.MLXTestCase):
         c_np = np.matmul(np.array(a).T, b)
         self.assertTrue(np.allclose(c, c_np))

+        # Check shapes
+        a = mx.random.normal((2, 3)).astype(mx.complex64)
+        b = mx.random.normal((3,))
+        self.assertEqual((a @ b).shape, (2,))
+
+        a = mx.random.normal((2, 3)).astype(mx.complex64)
+        b = mx.random.normal((3,))
+        c = mx.random.normal((2,))
+        self.assertEqual(mx.addmm(c, a, b).shape, (2,))
+
     def test_complex_gemm(self):
         M = 16
         K = 50
@@ -3078,6 +3078,13 @@ class TestOps(mlx_tests.MLXTestCase):
         )
         self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))

+    def test_complex_power(self):
+        out = mx.power(mx.array(0j), 2)
+        self.assertEqual(out.item(), 0j)
+
+        out = mx.power(mx.array(0j), float("nan"))
+        self.assertTrue(mx.isnan(out))
+

 class TestBroadcast(mlx_tests.MLXTestCase):
     def test_broadcast_shapes(self):