Mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-04 10:38:10 +08:00)
			
		
		
		
	Remove the kernel arg from get_launch_args (#2437)
This commit is contained in:
		@@ -128,7 +128,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  encoder.set_output_array(out);
 | 
			
		||||
 | 
			
		||||
  auto kernel = mod.get_kernel(kernel_name);
 | 
			
		||||
-  auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
 | 
			
		||||
+  auto [num_blocks, block_dims] = get_launch_args(out, large);
 | 
			
		||||
  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -229,7 +229,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  }
 | 
			
		||||
  encoder.set_output_array(out);
 | 
			
		||||
  auto kernel = mod.get_kernel(kernel_name);
 | 
			
		||||
-  auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
 | 
			
		||||
+  auto [num_blocks, block_dims] = get_launch_args(upd, large);
 | 
			
		||||
  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -317,7 +317,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  }
 | 
			
		||||
  encoder.set_output_array(out);
 | 
			
		||||
  auto kernel = mod.get_kernel(kernel_name);
 | 
			
		||||
-  auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
 | 
			
		||||
+  auto [num_blocks, block_dims] = get_launch_args(idx, large);
 | 
			
		||||
  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -421,7 +421,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  }
 | 
			
		||||
  encoder.set_output_array(out);
 | 
			
		||||
  auto kernel = mod.get_kernel(kernel_name);
 | 
			
		||||
-  auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
 | 
			
		||||
+  auto [num_blocks, block_dims] = get_launch_args(idx, large);
 | 
			
		||||
  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user