mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
ExpandDims
primitive (#1687)
* add squeeze primitive * simplify squeeze, use in gather * fix * fix * fix * fix * fix no cpu * use squeeze in matmul and friends * expand dims primitive * comment
This commit is contained in:
@@ -211,6 +211,10 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
copy_gpu(in, out, ctype);
|
||||
}
|
||||
|
||||
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto read_task = [out = out,
|
||||
@@ -381,6 +385,10 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
/* const Stream& s = */ stream());
|
||||
}
|
||||
|
||||
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
Reference in New Issue
Block a user