mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
CPU LU factorization and linear solvers (#1451)
* linalg solve backend * nits * more nits + fix * luf primitive and lu, solve, and solve_triangular backends * changes / nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -2329,7 +2329,6 @@ class Eigh : public Primitive {
|
||||
: Primitive(stream),
|
||||
uplo_(std::move(uplo)),
|
||||
compute_eigenvectors_(compute_eigenvectors) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
@@ -2350,4 +2349,16 @@ class Eigh : public Primitive {
|
||||
bool compute_eigenvectors_;
|
||||
};
|
||||
|
||||
/* LU Factorization primitive. */
|
||||
class LUF : public Primitive {
|
||||
public:
|
||||
explicit LUF(Stream stream) : Primitive(stream) {}
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(LUF)
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user