|
|
|
@@ -120,8 +120,8 @@ class AddByZero : public AnfVisitor { |
|
|
|
AnfNodePtr x_{nullptr}; |
|
|
|
}; |
|
|
|
|
|
|
|
// {prim::kPrimTensorAdd, {PrimZerosLikeTensor, Y}, X}, |
|
|
|
// {prim::kPrimTensorAdd, X, {PrimZerosLikeTensor, Y}} |
|
|
|
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, |
|
|
|
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} |
|
|
|
class TensorAddByZero : public AnfVisitor { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
@@ -135,7 +135,7 @@ class TensorAddByZero : public AnfVisitor { |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override { |
|
|
|
if (IsPrimitive(node, prim::kPrimZerosLikeTensor)) { |
|
|
|
if (IsPrimitive(node, prim::kPrimZerosLike)) { |
|
|
|
is_zero_ = true; |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -153,7 +153,7 @@ class TensorAddByZero : public AnfVisitor { |
|
|
|
AnfNodePtr x_{nullptr}; |
|
|
|
}; |
|
|
|
|
|
|
|
// {PrimMomentum, {PrimZerosLikeTensor, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} |
|
|
|
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} |
|
|
|
class OptUpdateZeroTensor : public AnfVisitor { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
@@ -163,13 +163,13 @@ class OptUpdateZeroTensor : public AnfVisitor { |
|
|
|
|
|
|
|
// {PrimMomentum, {...}, Y, Z, Xs} |
|
|
|
auto &inputs = node->cast<CNodePtr>()->inputs(); |
|
|
|
if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLikeTensor)) { |
|
|
|
if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto y = inputs[2]; |
|
|
|
auto z = inputs[3]; |
|
|
|
|
|
|
|
// {PrimZerosLikeTensor, X} |
|
|
|
// {kPrimZerosLike, X} |
|
|
|
if (inputs[1]->cast<CNodePtr>()->size() != 2) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|