diff --git a/mlx/array.h b/mlx/array.h
index d9fcfc58e..98eef2e33 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -224,6 +224,10 @@ class array {
     // Not copyable
     Data(const Data& d) = delete;
     Data& operator=(const Data& d) = delete;
+    Data(Data&& o) : buffer(o.buffer), d(o.d) {
+      o.buffer = allocator::Buffer(nullptr);
+      o.d = [](allocator::Buffer) {};
+    }
     ~Data() {
       d(buffer);
     }
diff --git a/mlx/backend/cpu/eigh.cpp b/mlx/backend/cpu/eigh.cpp
index b50f2c722..58d3634e8 100644
--- a/mlx/backend/cpu/eigh.cpp
+++ b/mlx/backend/cpu/eigh.cpp
@@ -12,6 +12,133 @@ namespace mlx::core {

 namespace {

+// Primary template; specialized below for real and complex scalar types.
+template <typename T, typename = void>
+struct EighWork {};
+
+// Workspace for real symmetric eigendecomposition (LAPACK syevd).
+template <typename T>
+struct EighWork<
+    T,
+    typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using R = T;
+
+  char jobz;
+  char uplo;
+  int N;
+  int lwork;
+  int liwork;
+  int info;
+  std::vector<array::Data> buffers;
+
+  EighWork(char jobz_, char uplo_, int N_)
+      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
+    T work;
+    int iwork;
+    // Workspace query (lwork == -1).
+    syevd<T>(
+        &jobz,
+        &uplo,
+        &N,
+        nullptr,
+        &N,
+        nullptr,
+        &work,
+        &lwork,
+        &iwork,
+        &liwork,
+        &info);
+    lwork = static_cast<int>(work);
+    liwork = iwork;
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
+  }
+
+  void run(T* vectors, T* values) {
+    syevd<T>(
+        &jobz,
+        &uplo,
+        &N,
+        vectors,
+        &N,
+        values,
+        static_cast<T*>(buffers[0].buffer.raw_ptr()),
+        &lwork,
+        static_cast<int*>(buffers[1].buffer.raw_ptr()),
+        &liwork,
+        &info);
+  }
+};
+
+// Workspace for complex Hermitian eigendecomposition (LAPACK heevd).
+template <>
+struct EighWork<std::complex<float>> {
+  using T = std::complex<float>;
+  using R = float;
+
+  char jobz;
+  char uplo;
+  int N;
+  int lwork;
+  int lrwork;
+  int liwork;
+  int info;
+  std::vector<array::Data> buffers;
+
+  EighWork(char jobz_, char uplo_, int N_)
+      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
+    T work;
+    R rwork;
+    int iwork;
+    // Workspace query (lwork == -1).
+    heevd<T>(
+        &jobz,
+        &uplo,
+        &N,
+        nullptr,
+        &N,
+        nullptr,
+        &work,
+        &lwork,
+        &rwork,
+        &lrwork,
+        &iwork,
+        &liwork,
+        &info);
+    lwork = static_cast<int>(work.real());
+    lrwork = static_cast<int>(rwork);
+    liwork = iwork;
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
+    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
+  }
+
+  void run(T* vectors, R* values) {
+    heevd<T>(
+        &jobz,
+        &uplo,
+        &N,
+        vectors,
+        &N,
+        values,
+        static_cast<T*>(buffers[0].buffer.raw_ptr()),
+        &lwork,
+        static_cast<R*>(buffers[1].buffer.raw_ptr()),
+        &lrwork,
+        static_cast<int*>(buffers[2].buffer.raw_ptr()),
+        &liwork,
+        &info);
+    if (jobz == 'V') {
+      // We have pre-transposed the vectors but we also must conjugate them
+      // when they are complex.
+      //
+      // We could vectorize this but it is so fast in comparison to heevd that
+      // it doesn't really matter.
+      for (int i = 0; i < N; i++) {
+        for (int j = 0; j < N; j++) {
+          *vectors = std::conj(*vectors);
+          vectors++;
+        }
+      }
+    }
+  }
+};
+
 template <typename T>
 void eigh_impl(
     array& vectors,
@@ -19,8 +146,10 @@ void eigh_impl(
     const std::string& uplo,
     bool compute_eigenvectors,
     Stream stream) {
+  using R = typename EighWork<T>::R;
+
   auto vec_ptr = vectors.data<T>();
-  auto eig_ptr = values.data<T>();
+  auto eig_ptr = values.data<R>();

   char jobz = compute_eigenvectors ? 'V' : 'N';
   auto& encoder = cpu::get_command_encoder(stream);
@@ -33,49 +162,17 @@ void eigh_impl(
        N = vectors.shape(-1),
        size = vectors.size()]() mutable {
         // Work query
-        int lwork = -1;
-        int liwork = -1;
-        int info;
-        {
-          T work;
-          int iwork;
-          syevd<T>(
-              &jobz,
-              &uplo,
-              &N,
-              nullptr,
-              &N,
-              nullptr,
-              &work,
-              &lwork,
-              &iwork,
-              &liwork,
-              &info);
-          lwork = static_cast<int>(work);
-          liwork = iwork;
-        }
+        EighWork<T> work(jobz, uplo, N);

-        auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
-        auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
+        // Work loop
         for (size_t i = 0; i < size / (N * N); ++i) {
-          syevd<T>(
-              &jobz,
-              &uplo,
-              &N,
-              vec_ptr,
-              &N,
-              eig_ptr,
-              static_cast<T*>(work_buf.buffer.raw_ptr()),
-              &lwork,
-              static_cast<int*>(iwork_buf.buffer.raw_ptr()),
-              &liwork,
-              &info);
+          work.run(vec_ptr, eig_ptr);
           vec_ptr += N * N;
           eig_ptr += N;
-          if (info != 0) {
+          if (work.info != 0) {
             std::stringstream msg;
             msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
-                << info;
+                << work.info;
             throw std::runtime_error(msg.str());
           }
         }
@@ -131,6 +228,10 @@ void Eigh::eval_cpu(
       eigh_impl<double>(
           vectors, values, uplo_, compute_eigenvectors_, stream());
       break;
+    case complex64:
+      eigh_impl<std::complex<float>>(
+          vectors, values, uplo_, compute_eigenvectors_, stream());
+      break;
     default:
       throw std::runtime_error(
           "[Eigh::eval_cpu] only supports float32 or float64.");
diff --git a/mlx/backend/cpu/lapack.h b/mlx/backend/cpu/lapack.h
index 411742d56..b242093ff 100644
--- a/mlx/backend/cpu/lapack.h
+++ b/mlx/backend/cpu/lapack.h
@@ -2,14 +2,14 @@

 #pragma once

-// Required for Visual Studio.
-// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
-#ifdef _MSC_VER
 #include <complex>
 #define LAPACK_COMPLEX_CUSTOM
 #define lapack_complex_float std::complex<float>
 #define lapack_complex_double std::complex<double>
-#endif
+#define lapack_complex_float_real(z) ((z).real())
+#define lapack_complex_float_imag(z) ((z).imag())
+#define lapack_complex_double_real(z) ((z).real())
+#define lapack_complex_double_imag(z) ((z).imag())

 #ifdef MLX_USE_ACCELERATE
 #include <Accelerate/Accelerate.h>
@@ -32,7 +32,7 @@

 #endif

-#define INSTANTIATE_LAPACK_TYPES(FUNC)                       \
+#define INSTANTIATE_LAPACK_REAL(FUNC)                        \
   template <typename T, typename... Args>                    \
   void FUNC(Args... args) {                                  \
     if constexpr (std::is_same_v<T, float>) {                \
@@ -42,12 +42,24 @@
     }                                                        \
   }

-INSTANTIATE_LAPACK_TYPES(geqrf)
-INSTANTIATE_LAPACK_TYPES(orgqr)
-INSTANTIATE_LAPACK_TYPES(syevd)
-INSTANTIATE_LAPACK_TYPES(geev)
-INSTANTIATE_LAPACK_TYPES(potrf)
-INSTANTIATE_LAPACK_TYPES(gesvdx)
-INSTANTIATE_LAPACK_TYPES(getrf)
-INSTANTIATE_LAPACK_TYPES(getri)
-INSTANTIATE_LAPACK_TYPES(trtri)
+INSTANTIATE_LAPACK_REAL(geqrf)
+INSTANTIATE_LAPACK_REAL(orgqr)
+INSTANTIATE_LAPACK_REAL(syevd)
+INSTANTIATE_LAPACK_REAL(geev)
+INSTANTIATE_LAPACK_REAL(potrf)
+INSTANTIATE_LAPACK_REAL(gesvdx)
+INSTANTIATE_LAPACK_REAL(getrf)
+INSTANTIATE_LAPACK_REAL(getri)
+INSTANTIATE_LAPACK_REAL(trtri)
+
+#define INSTANTIATE_LAPACK_COMPLEX(FUNC)                             \
+  template <typename T, typename... Args>                            \
+  void FUNC(Args... args) {                                          \
+    if constexpr (std::is_same_v<T, std::complex<float>>) {          \
+      MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...);         \
+    } else if constexpr (std::is_same_v<T, std::complex<double>>) {  \
+      MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...);         \
+    }                                                                \
+  }
+
+INSTANTIATE_LAPACK_COMPLEX(heevd)
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index e0f4ec2e6..144f9a880 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -27,6 +27,15 @@ void check_float(Dtype dtype, const std::string& prefix) {
   }
 }

+void check_float_or_complex(Dtype dtype, const std::string& prefix) {
+  if (dtype != float32 && dtype != float64 && dtype != complex64) {
+    std::ostringstream msg;
+    msg << prefix << " Arrays must have type float32, float64 or complex64. "
+        << "Received array with type " << dtype << ".";
+    throw std::invalid_argument(msg.str());
+  }
+}
+
 Dtype at_least_float(const Dtype& d) {
   return issubdtype(d, inexact) ? d : promote_types(d, float32);
 }
@@ -493,7 +502,7 @@ void validate_eig(
     const StreamOrDevice& stream,
     const std::string fname) {
   check_cpu_stream(stream, fname);
-  check_float(a.dtype(), fname);
+  check_float_or_complex(a.dtype(), fname);

   if (a.ndim() < 2) {
     std::ostringstream msg;
@@ -513,9 +522,10 @@ array eigvalsh(
     StreamOrDevice s /* = {} */) {
   validate_eig(a, s, "[linalg::eigvalsh]");
   Shape out_shape(a.shape().begin(), a.shape().end() - 1);
+  Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
   return array(
       std::move(out_shape),
-      a.dtype(),
+      eigval_type,
       std::make_shared<Eigh>(to_stream(s), UPLO, false),
       {a});
 }
@@ -525,9 +535,10 @@ std::pair<array, array> eigh(
     std::string UPLO /* = "L" */,
     StreamOrDevice s /* = {} */) {
   validate_eig(a, s, "[linalg::eigh]");
+  Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
   auto out = array::make_arrays(
       {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
-      {a.dtype(), a.dtype()},
+      {eigval_type, a.dtype()},
       std::make_shared<Eigh>(to_stream(s), UPLO, true),
       {a});
   return std::make_pair(out[0], out[1]);
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index f65da1ff7..f5eeda837 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -423,6 +423,13 @@ class TestLinalg(mlx_tests.MLXTestCase):
         A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2
         check_eigs_and_vecs(A_np)

+        # Test with complex inputs
+        A_np = (
+            np.random.randn(8, 8, 2).astype(np.float32).view(np.complex64).squeeze(-1)
+        )
+        A_np = A_np + A_np.T.conj()
+        check_eigs_and_vecs(A_np)
+
         # Test error cases
         with self.assertRaises(ValueError):
             mx.linalg.eigh(mx.array([1.0, 2.0]))  # 1D array