mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 21:04:41 +08:00
Up to 10x faster scatter. (#709)
* Faster scatter. Add specialization for 1-d index tensors. * Address review comments. - Check for row contiguity of index, update tensors instead of checking strides. - Add support for 1d specialization with col contiguous update tensor, along with a test. * Nit1 Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Nit2 Co-authored-by: Awni Hannun <awni.hannun@gmail.com> --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -1858,6 +1858,14 @@ TEST_CASE("test scatter") {
|
||||
out = scatter(in, inds, updates, 0);
|
||||
CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item<bool>());
|
||||
|
||||
// Array scatters with col contiguous updates
|
||||
in = zeros({4, 4}, float32);
|
||||
inds = array({0, 1, 2, 3});
|
||||
updates = transpose(reshape(arange(16, float32), {4, 1, 4}));
|
||||
out = scatter(in, inds, updates, 0);
|
||||
CHECK(array_equal(out, transpose(reshape(arange(16, float32), {4, 4})))
|
||||
.item<bool>());
|
||||
|
||||
// Irregular strided index and reduce collision test
|
||||
in = zeros({10}, float32);
|
||||
inds = broadcast_to(array(3), {10});
|
||||
@@ -1877,10 +1885,10 @@ TEST_CASE("test scatter") {
|
||||
|
||||
// Irregularly strided updates test
|
||||
in = ones({3, 3});
|
||||
updates = broadcast_to(array({0, 0, 0}), {1, 3, 3});
|
||||
updates = broadcast_to(array({2, 2, 2}), {1, 3, 3});
|
||||
inds = array({0});
|
||||
out = scatter(in, inds, updates, 0);
|
||||
CHECK(array_equal(out, zeros({3, 3})).item<bool>());
|
||||
CHECK(array_equal(out, ones({3, 3}) * 2).item<bool>());
|
||||
|
||||
// Along different axis
|
||||
in = zeros({2, 3});
|
||||
|
Reference in New Issue
Block a user