mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix extension (#1740)
This commit is contained in:
parent
491fa95b1f
commit
a64a8dfe45
@ -58,9 +58,9 @@ mx::array axpby(
|
||||
// Construct the array as the output of the Axpby primitive
|
||||
// with the broadcasted and upcasted arrays as inputs
|
||||
return mx::array(
|
||||
/* const std::vector<int>& shape = */ out_shape,
|
||||
/* const mx::Shape& shape = */ out_shape,
|
||||
/* mx::Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<mx::Primitive> primitive = */
|
||||
/* std::shared_ptr<mx::Primitive> primitive = */
|
||||
std::make_shared<Axpby>(to_stream(s), alpha, beta),
|
||||
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
|
||||
}
|
||||
@ -279,7 +279,7 @@ void Axpby::eval_gpu(
|
||||
if (!contiguous_kernel) {
|
||||
compute_encoder.set_vector_bytes(x.shape(), 5);
|
||||
compute_encoder.set_vector_bytes(x.strides(), 6);
|
||||
compute_encoder.set_bytes(y.strides(), 7);
|
||||
compute_encoder.set_vector_bytes(y.strides(), 7);
|
||||
compute_encoder.set_bytes(ndim, 8);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user