	Fix ternary for large arrays (#1359)
* fix ternary for large arrays
* fix
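For context: MLX's element-wise ternary kernels back ops like `where`, and this commit lets them handle outputs whose element count exceeds the 32-bit range of a 1D Metal thread index. Below is a rough repro sketch against the C++ API (`mlx/ops.h`); the shapes are illustrative only and would need a very large amount of memory, so treat it as a sketch of the triggering condition rather than a runnable test:

```cpp
// Rough repro sketch (illustrative shapes, not a practical test).
// Assumes MLX's C++ API: full, zeros, ones, where from mlx/ops.h.
#include "mlx/mlx.h"

int main() {
  using namespace mlx::core;
  // 2 * 2^16 * 2^16 = 2^33 elements: past what a 32-bit linear
  // thread index can address, so the new 2D grid path is exercised.
  auto cond = full({2, 1 << 16, 1 << 16}, true);
  auto a = zeros({2, 1 << 16, 1 << 16});
  auto b = ones({2, 1 << 16, 1 << 16});
  auto out = where(cond, a, b); // `where` is the ternary op here
  out.eval();
  return 0;
}
```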
@@ -12,6 +12,7 @@ namespace {
 // TODO: Add support for more combinations of input types.
 enum class TernaryOpType {
   ScalarScalarScalar,
+  VectorVectorVector,
   General,
 };
 
@@ -20,6 +21,12 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
   TernaryOpType topt;
   if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
     topt = TernaryOpType::ScalarScalarScalar;
+  } else if (
+      (a.flags().row_contiguous && b.flags().row_contiguous &&
+       c.flags().row_contiguous) ||
+      (a.flags().col_contiguous && b.flags().col_contiguous &&
+       c.flags().col_contiguous)) {
+    topt = TernaryOpType::VectorVectorVector;
   } else {
     topt = TernaryOpType::General;
   }
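The new branch takes the fast path only when all three inputs share the same layout (all row-contiguous or all col-contiguous), so the kernel can walk them as flat buffers. A minimal standalone sketch of what the `VectorVectorVector` case buys (toy code, not MLX's internals):

```cpp
#include <cstddef>

// Toy sketch: when a, b, and c are all contiguous in the same order,
// where(a, b, c) collapses to one flat loop with no stride arithmetic.
template <typename T, typename Op>
void ternary_vector(const bool* a, const T* b, const T* c, T* out,
                    std::size_t n, Op op) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = op(a[i], b[i], c[i]);
  }
}

int main() {
  bool cond[4] = {true, false, true, false};
  float x[4] = {1, 2, 3, 4};
  float y[4] = {10, 20, 30, 40};
  float out[4];
  // Pick x where cond is true, else y.
  ternary_vector(cond, x, y, out, 4,
                 [](bool c, float xi, float yi) { return c ? xi : yi; });
  return 0; // out = {1, 20, 3, 40}
}
```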
@@ -33,11 +40,32 @@ void set_ternary_op_output_data(
     array& out,
     TernaryOpType topt,
     bool donate_with_move = false) {
+  auto maybe_donate = [&out, donate_with_move](const array& x) {
+    if (x.is_donatable() && x.itemsize() == out.itemsize()) {
+      if (donate_with_move) {
+        out.move_shared_buffer(x);
+      } else {
+        out.copy_shared_buffer(x);
+      }
+      return true;
+    }
+    return false;
+  };
+
   switch (topt) {
     case TernaryOpType::ScalarScalarScalar:
       out.set_data(
           allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
       break;
+    case TernaryOpType::VectorVectorVector:
+      if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
+        out.set_data(
+            allocator::malloc_or_wait(out.itemsize() * b.data_size()),
+            b.data_size(),
+            b.strides(),
+            b.flags());
+      }
+      break;
     case TernaryOpType::General:
       out.set_data(allocator::malloc_or_wait(out.nbytes()));
       break;
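The `maybe_donate` helper reuses an input's buffer as the output when the input is donatable and the element sizes match, falling back to a fresh allocation only if no input qualifies. A self-contained sketch of that donate-or-allocate pattern, using a toy tensor type (names here are illustrative, not MLX's):

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Toy tensor: use_count() == 1 approximates "donatable"
// (no other tensor aliases this storage).
struct Tensor {
  std::shared_ptr<std::vector<float>> buf;
};

bool maybe_donate(Tensor& out, Tensor& x) {
  if (x.buf && x.buf.use_count() == 1) {
    out.buf = std::move(x.buf); // out now owns x's storage
    return true;
  }
  return false;
}

void set_output(Tensor& out, Tensor& a, Tensor& b, Tensor& c, std::size_t n) {
  // Try each input in turn; allocate only when none can be reused.
  if (!(maybe_donate(out, a) || maybe_donate(out, b) || maybe_donate(out, c))) {
    out.buf = std::make_shared<std::vector<float>>(n);
  }
}
```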
@@ -56,9 +56,12 @@ void ternary_op_gpu_inplace(
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
-  compute_encoder.set_input_array(a, 0);
-  compute_encoder.set_input_array(b, 1);
-  compute_encoder.set_input_array(c, 2);
+  bool donate_a = a.data_shared_ptr() == nullptr;
+  bool donate_b = b.data_shared_ptr() == nullptr;
+  bool donate_c = c.data_shared_ptr() == nullptr;
+  compute_encoder.set_input_array(donate_a ? out : a, 0);
+  compute_encoder.set_input_array(donate_b ? out : b, 1);
+  compute_encoder.set_input_array(donate_c ? out : c, 2);
   compute_encoder.set_output_array(out, 3);
 
   if (topt == TernaryOpType::General) {
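On the GPU side, an input that was donated into `out` no longer owns its buffer, which is why `data_shared_ptr() == nullptr` identifies it and the encoder binds `out` at that slot instead. The ownership transfer mirrors ordinary `std::shared_ptr` move semantics; a tiny illustration (an analogy, not MLX code):

```cpp
#include <cassert>
#include <memory>
#include <utility>

int main() {
  // Moving a shared_ptr leaves the source empty, which is how a donated
  // input can be recognized afterwards (cf. data_shared_ptr() == nullptr).
  auto src = std::make_shared<int>(42);
  auto dst = std::move(src);
  assert(src == nullptr); // "donated": src no longer owns the data
  assert(*dst == 42);     // the buffer lives on in its new owner
  return 0;
}
```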
@@ -91,9 +94,10 @@ void ternary_op_gpu_inplace(
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
-    // Launch a 1D grid of threads
+    // Launch a 1D or 2D grid of threads
     size_t nthreads = out.data_size();
-    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                                 : MTL::Size(nthreads, 1, 1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
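This hunk is the actual large-array fix: a 1D dispatch addresses elements with a single coordinate, which overflows past the 32-bit range, so `use_2d` switches to `get_2d_grid_dims`, spreading the same threads over two grid dimensions whose coordinates each stay small. A hedged sketch of the idea (not MLX's implementation, which derives the split from the output shape and strides):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch: factor a linear size n into (dim0, dim1) so each grid coordinate
// fits in 32 bits; a kernel thread then recovers its element in 64 bits as
// index = y * dim0 + x. Threads past n must be masked in the kernel, since
// the rounded-up grid can overshoot by up to dim0 - 1 threads.
void grid_2d(std::uint64_t n, std::uint64_t& dim0, std::uint64_t& dim1) {
  dim0 = n <= UINT32_MAX ? n : UINT32_MAX;
  dim1 = dim0 ? (n + dim0 - 1) / dim0 : 0; // ceil(n / dim0)
}

int main() {
  std::uint64_t d0 = 0, d1 = 0;
  grid_2d(5'000'000'000ULL, d0, d1); // an array past the 32-bit limit
  std::printf("grid = %llu x %llu\n",
              static_cast<unsigned long long>(d0),
              static_cast<unsigned long long>(d1));
  return 0;
}
```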
Author: Awni Hannun