mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Gather mm new kernel and small refactoring (#2040)
This commit is contained in:
committed by
GitHub
parent
e9e268336b
commit
99eefd2ec0
13
mlx/ops.cpp
13
mlx/ops.cpp
@@ -4499,6 +4499,7 @@ array gather_mm(
|
||||
array b,
|
||||
std::optional<array> lhs_indices_ /* = std::nullopt */,
|
||||
std::optional<array> rhs_indices_ /* = std::nullopt */,
|
||||
bool sorted_indices /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// If no indices, fall back to full matmul
|
||||
if (!lhs_indices_ && !rhs_indices_) {
|
||||
@@ -4574,12 +4575,18 @@ array gather_mm(
|
||||
out_shape.push_back(M);
|
||||
out_shape.push_back(N);
|
||||
|
||||
// Caculate array
|
||||
// Make the output array
|
||||
auto out = array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<GatherMM>(to_stream(s)),
|
||||
{a, b, lhs_indices, rhs_indices});
|
||||
std::make_shared<GatherMM>(
|
||||
to_stream(s),
|
||||
sorted_indices && !rhs_indices_,
|
||||
sorted_indices && !lhs_indices_),
|
||||
{std::move(a),
|
||||
std::move(b),
|
||||
std::move(lhs_indices),
|
||||
std::move(rhs_indices)});
|
||||
|
||||
// Remove the possibly inserted singleton dimensions
|
||||
std::vector<int> axes;
|
||||
|
||||
Reference in New Issue
Block a user