mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	 44c1ce5e6a
			
		
	
	44c1ce5e6a
	
	
	
		
			
			* spelling: accumulates Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: across Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: additional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: against Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: among Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: array Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: at least Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: available Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: axes Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: basically Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bfloat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bounds Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: broadcast Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: buffer Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: class Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: coefficients Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: collision Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: combinations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: committing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: computation Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: consider Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: constructing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: conversions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: correctly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: corresponding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: declaration Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: default Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dependency Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destination Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destructor Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dimensions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: divided Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: element-wise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: elements Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: endianness Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: equivalent Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: explicitly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: github Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: indices Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: irregularly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: memory Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: metallib Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: negative Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: notable Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: optional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: otherwise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: overridden Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partially Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partition Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perform Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perturbations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: positively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: primitive Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeats Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respect Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respectively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: result Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: rounding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: separate Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: skipping Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: structure Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: the Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: transpose Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unnecessary Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unneeded Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unsupported Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> --------- Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
		
			
				
	
	
		
			208 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			208 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright © 2023 Apple Inc.
 | |
| 
 | |
| #include <iostream>
 | |
| 
 | |
| #include "doctest/doctest.h"
 | |
| 
 | |
| #include "mlx/mlx.h"
 | |
| #include "mlx/primitives.h"
 | |
| 
 | |
| using namespace mlx::core;
 | |
| 
 | |
| void test_arg_reduce_small(
 | |
|     Device d,
 | |
|     const array& x,
 | |
|     ArgReduce::ReduceType r,
 | |
|     std::vector<int> out_shape,
 | |
|     int axis,
 | |
|     std::vector<int> expected_output) {
 | |
|   auto s = default_stream(d);
 | |
|   auto y =
 | |
|       array(out_shape, uint32, std::make_unique<ArgReduce>(s, r, axis), {x});
 | |
|   y.eval();
 | |
|   const uint32_t* ydata = y.data<uint32_t>();
 | |
|   for (int i = 0; i < y.size(); i++) {
 | |
|     CHECK_EQ(expected_output[i], ydata[i]);
 | |
|   }
 | |
| }
 | |
| 
 | |
| void test_arg_reduce_against_cpu(
 | |
|     const array& x,
 | |
|     ArgReduce::ReduceType r,
 | |
|     std::vector<int> out_shape,
 | |
|     int axis) {
 | |
|   auto y1 = array(
 | |
|       out_shape,
 | |
|       uint32,
 | |
|       std::make_unique<ArgReduce>(default_stream(Device::cpu), r, axis),
 | |
|       {x});
 | |
|   auto y2 = array(
 | |
|       out_shape,
 | |
|       uint32,
 | |
|       std::make_unique<ArgReduce>(default_stream(Device::gpu), r, axis),
 | |
|       {x});
 | |
|   y1.eval();
 | |
|   y2.eval();
 | |
|   CHECK(array_equal(y1, y2).item<bool>());
 | |
| }
 | |
| 
 | |
| TEST_CASE("test arg reduce small") {
 | |
|   auto x = array(
 | |
|       {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
 | |
|        0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
 | |
|       {2, 3, 4});
 | |
|   x.eval();
 | |
|   test_arg_reduce_small(
 | |
|       Device::cpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3});
 | |
|   test_arg_reduce_small(
 | |
|       Device::cpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 2, 0, 1, 1, 2});
 | |
|   test_arg_reduce_small(
 | |
|       Device::cpu,
 | |
|       x,
 | |
|       ArgReduce::ArgMin,
 | |
|       {3, 4},
 | |
|       0,
 | |
|       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
 | |
|   test_arg_reduce_small(
 | |
|       Device::cpu, x, ArgReduce::ArgMax, {2, 3}, 2, {3, 0, 1, 3, 0, 1});
 | |
|   test_arg_reduce_small(
 | |
|       Device::cpu, x, ArgReduce::ArgMax, {2, 4}, 1, {1, 2, 2, 0, 1, 2, 2, 0});
 | |
|   test_arg_reduce_small(
 | |
|       Device::cpu,
 | |
|       x,
 | |
|       ArgReduce::ArgMax,
 | |
|       {3, 4},
 | |
|       0,
 | |
|       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
 | |
| 
 | |
|   if (!metal::is_available()) {
 | |
|     INFO("Skipping arg reduction gpu tests");
 | |
|     return;
 | |
|   }
 | |
| 
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3});
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 2, 0, 1, 1, 2});
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu,
 | |
|       x,
 | |
|       ArgReduce::ArgMin,
 | |
|       {3, 4},
 | |
|       0,
 | |
|       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu, x, ArgReduce::ArgMax, {2, 3}, 2, {3, 0, 1, 3, 0, 1});
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu, x, ArgReduce::ArgMax, {2, 4}, 1, {1, 2, 2, 0, 1, 2, 2, 0});
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu,
 | |
|       x,
 | |
|       ArgReduce::ArgMax,
 | |
|       {3, 4},
 | |
|       0,
 | |
|       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
 | |
| }
 | |
| 
 | |
| TEST_CASE("test arg reduce against cpu") {
 | |
|   if (!metal::is_available()) {
 | |
|     INFO("Skipping arg reduction gpu tests");
 | |
|     return;
 | |
|   }
 | |
| 
 | |
|   auto x = random::uniform(array(0.0), array(1.0), {127, 92, 55});
 | |
|   x.eval();
 | |
|   test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {127, 92}, 2);
 | |
|   test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {127, 55}, 1);
 | |
|   test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {92, 55}, 0);
 | |
|   test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {127, 92}, 2);
 | |
|   test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {127, 55}, 1);
 | |
|   test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {92, 55}, 0);
 | |
| 
 | |
|   auto y = random::uniform(array(0.0), array(1.0), {1234});
 | |
|   y.eval();
 | |
|   test_arg_reduce_against_cpu(y, ArgReduce::ArgMin, {}, 0);
 | |
|   test_arg_reduce_against_cpu(y, ArgReduce::ArgMax, {}, 0);
 | |
| }
 | |
| 
 | |
| void test_arg_reduce_small_bool(
 | |
|     Device d,
 | |
|     ArgReduce::ReduceType r,
 | |
|     std::vector<int> out_shape,
 | |
|     int axis,
 | |
|     std::vector<int> expected_output) {
 | |
|   auto s = default_stream(d);
 | |
|   auto x = array(
 | |
|       {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
 | |
|        0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
 | |
|       {2, 3, 4});
 | |
|   x.eval();
 | |
|   auto y =
 | |
|       array(out_shape, uint32, std::make_unique<ArgReduce>(s, r, axis), {x});
 | |
|   y.eval();
 | |
|   const uint32_t* ydata = y.data<uint32_t>();
 | |
|   for (int i = 0; i < y.size(); i++) {
 | |
|     CHECK_EQ(expected_output[i], ydata[i]);
 | |
|   }
 | |
| }
 | |
| 
 | |
| TEST_CASE("test arg reduce bool") {
 | |
|   if (!metal::is_available()) {
 | |
|     INFO("Skipping arg reduction gpu tests");
 | |
|     return;
 | |
|   }
 | |
|   auto x = array(
 | |
|       {false, true,  true,  false, false, false, false, true,
 | |
|        true,  false, true,  true,  false, true,  true,  false,
 | |
|        false, false, false, true,  true,  false, true,  true},
 | |
|       {2, 3, 4});
 | |
|   x.eval();
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 0, 1, 0, 0, 1});
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 0, 0, 1, 1, 0});
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu,
 | |
|       x,
 | |
|       ArgReduce::ArgMin,
 | |
|       {3, 4},
 | |
|       0,
 | |
|       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu, x, ArgReduce::ArgMax, {2, 3}, 2, {1, 3, 0, 1, 3, 0});
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu, x, ArgReduce::ArgMax, {2, 4}, 1, {2, 0, 0, 1, 2, 0, 0, 1});
 | |
|   test_arg_reduce_small(
 | |
|       Device::gpu,
 | |
|       x,
 | |
|       ArgReduce::ArgMax,
 | |
|       {3, 4},
 | |
|       0,
 | |
|       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
 | |
| }
 | |
| 
 | |
| TEST_CASE("test arg reduce edge cases") {
 | |
|   auto a = argmin(array(1.0));
 | |
|   CHECK_EQ(a.item<uint32_t>(), 0);
 | |
|   auto b = argmax(array(1.0));
 | |
|   CHECK_EQ(b.item<uint32_t>(), 0);
 | |
|   CHECK_THROWS(argmin(array({})));
 | |
|   CHECK_THROWS(argmax(array({})));
 | |
| }
 | |
| 
 | |
| TEST_CASE("test arg reduce irregular strides") {
 | |
|   auto x = array(
 | |
|       {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
 | |
|        0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
 | |
|       {2, 3, 4});
 | |
|   x = transpose(x, {2, 0, 1});
 | |
|   x.eval();
 | |
|   test_arg_reduce_small(
 | |
|       Device::cpu, x, ArgReduce::ArgMin, {4, 2}, 2, {0, 0, 1, 1, 1, 1, 2, 2});
 | |
| 
 | |
|   if (!metal::is_available()) {
 | |
|     INFO("Skipping arg reduction gpu tests");
 | |
|     return;
 | |
|   }
 | |
| }
 |