|
|
|
@@ -30,6 +30,7 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace irpass { |
|
|
|
|
|
|
|
class SpecialOpEliminater { |
|
|
|
public: |
|
|
|
SpecialOpEliminater() |
|
|
|
@@ -156,12 +157,27 @@ class ZeroLikeFillZero : public AnfVisitor { |
|
|
|
if (y_ == nullptr || node->func_graph() == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if ((y_->abstract() == nullptr) || !y_->abstract()->isa<abstract::AbstractTensor>()) { |
|
|
|
auto fg = node->func_graph(); |
|
|
|
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_}); |
|
|
|
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_}); |
|
|
|
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))}); |
|
|
|
} |
|
|
|
|
|
|
|
abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast<abstract::AbstractTensorPtr>(); |
|
|
|
|
|
|
|
TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); |
|
|
|
std::vector<int> tensor_shape = tensor_abstract->shape()->shape(); |
|
|
|
|
|
|
|
tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); |
|
|
|
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); |
|
|
|
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true)); |
|
|
|
std::memset(data, 0, mem_size); |
|
|
|
|
|
|
|
auto fg = node->func_graph(); |
|
|
|
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_}); |
|
|
|
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_}); |
|
|
|
auto new_cnode = NewValueNode(new_tensor_ptr); |
|
|
|
new_cnode->set_abstract(new_tensor_ptr->ToAbstract()); |
|
|
|
|
|
|
|
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))}); |
|
|
|
return new_cnode; |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override { y_ = node; } |
|
|
|
|