Browse Source

!6205 [MSLITE]fix reshape avgpooling and convbiasaddfusion bug

Merge pull request !6205 from zhaodezan/master
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
94ca69bb03
3 changed files with 13 additions and 1 deletions
  1. +2
    -0
      mindspore/lite/src/ops/primitive_c.cc
  2. +1
    -1
      mindspore/lite/src/ops/reshape.cc
  3. +10
    -0
      mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc

+ 2
- 0
mindspore/lite/src/ops/primitive_c.cc View File

@@ -399,6 +399,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<SoftMax>(prim, inputs, quantType);
} else if (op_type == "StridedSlice") {
return NewPrimitiveC<StridedSlice>(prim, inputs, quantType);
} else if (op_type == "AvgPool") {
return NewPrimitiveC<Pooling>(prim, inputs, quantType);


#ifdef SUPPORT_TRAIN


+ 1
- 1
mindspore/lite/src/ops/reshape.cc View File

@@ -169,7 +169,7 @@ int Reshape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
std::vector<int> out_shape;
if (inputs_.size() == kDoubleNum) {
auto shape_tensor = inputs_.at(1);
if (shape_tensor->MutableData() == nullptr) {
if (shape_tensor->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}


+ 10
- 0
mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc View File

@@ -24,6 +24,7 @@
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "src/ops/add.h"

namespace mindspore::opt {
namespace {
@@ -155,6 +156,15 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
auto add_node = node->cast<CNodePtr>();
CheckIfCNodeIsNull(add_node);
CheckInputSize(add_node, kAddInputsLength);
if (GetCNodeType(add_node) == schema::PrimitiveType_Add) {
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(add_node->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Add>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Add>>(primitive_c);
MS_ASSERT(primc != nullptr);
if (primc->GetActivationType() != schema::ActivationType_NO_ACTIVATION) {
return add_node;
}
}

AnfNodePtr conv_node_anf = add_node->input(1);
CheckIfAnfNodeIsNull(conv_node_anf);


Loading…
Cancel
Save