mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Fix reduce sum/prod overflow (#2477)
This commit is contained in:
		| @@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") { | ||||
|     CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1); | ||||
|   } | ||||
|  | ||||
|   // sum and prod overflow | ||||
|   { | ||||
|     auto a = full({256, 2, 2}, 1u, uint8); | ||||
|     CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4); | ||||
|     CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1); | ||||
|  | ||||
|     a = full({65535, 2, 2}, 1u, uint16); | ||||
|     CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4); | ||||
|     CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1); | ||||
|   } | ||||
| } | ||||
|  | ||||
| TEST_CASE("test gpu reduce with axes") { | ||||
|   // reducing only some axes and irregular layouts | ||||
|   { | ||||
|     array a(1.0f); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Abe Leininger
					Abe Leininger