mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix copying scalars by adding fill_gpu (#1402)
* fix copying scalars by adding fill_gpu * Another copy scalar changed to fill --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -552,7 +552,7 @@ void winograd_conv_2D_gpu(
|
||||
|
||||
// Fill with zeros
|
||||
array zero_arr = array(0, in.dtype());
|
||||
copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
|
||||
fill_gpu(zero_arr, in_padded, s);
|
||||
copies_w.push_back(zero_arr);
|
||||
|
||||
// Pick input slice from padded
|
||||
@@ -571,7 +571,6 @@ void winograd_conv_2D_gpu(
|
||||
|
||||
copies_w.push_back(in_padded_slice);
|
||||
copies_w.push_back(in_padded);
|
||||
copies_w.push_back(zero_arr);
|
||||
|
||||
MLXConvParams<2> conv_params_updated{
|
||||
/* const int N = */ in_padded.shape(0),
|
||||
|
||||
Reference in New Issue
Block a user