Up to 10x faster scatter. (#709)

* Faster scatter.

Add specialization for 1-d index tensors.

* Address review comments.

- Check for row contiguity of index, update tensors
  instead of checking strides.
- Add support for 1d specialization with col contiguous update
  tensor, along with a test.

* Nit1

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Nit2

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Vijay Krish
2024-02-21 11:09:30 -08:00
committed by GitHub
parent 7dcdd88e27
commit 972d9a3aea
4 changed files with 244 additions and 83 deletions

View File

@@ -1858,6 +1858,14 @@ TEST_CASE("test scatter") {
out = scatter(in, inds, updates, 0);
CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item<bool>());
// Array scatters with col contiguous updates
in = zeros({4, 4}, float32);
inds = array({0, 1, 2, 3});
updates = transpose(reshape(arange(16, float32), {4, 1, 4}));
out = scatter(in, inds, updates, 0);
CHECK(array_equal(out, transpose(reshape(arange(16, float32), {4, 4})))
.item<bool>());
// Irregular strided index and reduce collision test
in = zeros({10}, float32);
inds = broadcast_to(array(3), {10});
@@ -1877,10 +1885,10 @@ TEST_CASE("test scatter") {
// Irregularly strided updates test
in = ones({3, 3});
updates = broadcast_to(array({0, 0, 0}), {1, 3, 3});
updates = broadcast_to(array({2, 2, 2}), {1, 3, 3});
inds = array({0});
out = scatter(in, inds, updates, 0);
CHECK(array_equal(out, zeros({3, 3})).item<bool>());
CHECK(array_equal(out, ones({3, 3}) * 2).item<bool>());
// Along different axis
in = zeros({2, 3});