Browse Source

!7775 [MS][LITE]Fix softmax parser, fix mul bug, fix lrn bug

Merge pull request !7775 from gongdaguo/fix_softmax
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
8aaa20f918
9 changed files with 50 additions and 9 deletions
  1. +1
    -1
      mindspore/lite/schema/ops.fbs
  2. +21
    -6
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc
  3. +5
    -0
      mindspore/lite/test/models_onnx.cfg
  4. +1
    -0
      mindspore/lite/tools/common/node_util.cc
  5. +1
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
  6. +2
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
  7. +9
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h
  8. +1
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc
  9. +9
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc

+ 1
- 1
mindspore/lite/schema/ops.fbs View File

@@ -172,7 +172,7 @@ table Concat {
}

table SoftMax {
axis: int;
axis: int = -1;
}

table Activation {


+ 21
- 6
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc View File

@@ -73,12 +73,27 @@ int ArithmeticCPUKernel::ReSize() {
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
memcpy(arithmeticParameter_->in_shape0_, static_cast<void *>(in_tensors_[0]->shape().data()),
in_tensors_[0]->shape().size() * sizeof(int));
memcpy(arithmeticParameter_->in_shape1_, static_cast<void *>(in_tensors_[1]->shape().data()),
in_tensors_[1]->shape().size() * sizeof(int));
memcpy(arithmeticParameter_->out_shape_, static_cast<void *>(out_tensors_[0]->shape().data()),
out_tensors_[0]->shape().size() * sizeof(int));
for (size_t i = 0; i < in_tensors_[0]->shape().size(); i++) {
if (arithmeticParameter_->in_shape0_[i] == -1) {
memcpy(arithmeticParameter_->in_shape0_, static_cast<void *>(in_tensors_[0]->shape().data()),
in_tensors_[0]->shape().size() * sizeof(int));
break;
}
}
for (size_t i = 0; i < in_tensors_[1]->shape().size(); i++) {
if (arithmeticParameter_->in_shape1_[i] == -1) {
memcpy(arithmeticParameter_->in_shape1_, static_cast<void *>(in_tensors_[1]->shape().data()),
in_tensors_[1]->shape().size() * sizeof(int));
break;
}
}
for (size_t i = 0; i < out_tensors_[0]->shape().size(); i++) {
if (arithmeticParameter_->out_shape_[i] == -1) {
memcpy(arithmeticParameter_->out_shape_, static_cast<void *>(out_tensors_[0]->shape().data()),
out_tensors_[0]->shape().size() * sizeof(int));
break;
}
}

if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
switch (arithmeticParameter_->op_parameter_.type_) {


+ 5
- 0
mindspore/lite/test/models_onnx.cfg View File

@@ -2,6 +2,11 @@ mtk_detect-mbv2-shortcut-400-400-simplified.onnx
mtk_emotions-d2012-75.8%.onnx
mtk_face_features_v3.onnx
emotion-ferplus-8.onnx
#rcnn-ilsvrc13-9.onnx
efficientnet-lite4-11.onnx
mobilenetv2-7.onnx
shufflenet-v2-10.onnx
squeezenet1.1-7.onnx
ml_face_3d.onnx
gts_version-RFB-320_simplified.onnx
mnist-8.onnx


+ 1
- 0
mindspore/lite/tools/common/node_util.cc View File

@@ -42,6 +42,7 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_Pooling,
schema::PrimitiveType_LocalResponseNormalization,
schema::PrimitiveType_Resize,
schema::PrimitiveType_BatchNorm,
schema::PrimitiveType_FusedBatchNorm,


+ 1
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc View File

@@ -608,6 +608,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
OnnxNodeParser::set_opset_version(onnx_model.opset_import().Get(0).version());
const onnx::GraphProto &onnx_graph = onnx_model.graph();
MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();



+ 2
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc View File

@@ -20,6 +20,8 @@

namespace mindspore {
namespace lite {
int OnnxNodeParser::opset_version_ = 0;

schema::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) {
if (onnx_node_attr.s() == "NOTSET") {
return schema::PadMode_NOTSET;


+ 9
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h View File

@@ -37,12 +37,21 @@ class OnnxNodeParser {

STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type);

static STATUS set_opset_version(int version) {
opset_version_ = version;
return RET_OK;
}
static int opset_version() { return opset_version_; }

protected:
schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr);

void Split(const std::string &src_str, std::vector<std::string> *dst_str, const std::string &chr);

const std::string &name;

private:
static int opset_version_;
};
} // namespace lite
} // namespace mindspore


+ 1
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc View File

@@ -94,7 +94,7 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
}
}
if (attribute_name == "ceil_mode") {
if (onnx_node_attr.f() == 0) {
if (onnx_node_attr.i() == 0) {
attr->roundMode = schema::RoundMode_FLOOR;
} else {
attr->roundMode = schema::RoundMode_CEIL;


+ 9
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc View File

@@ -38,13 +38,21 @@ STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
return RET_NULL_PTR;
}

bool axis_is_def = true;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
axis_is_def = false;
}
}
if (axis_is_def) {
if (OnnxNodeParser::opset_version() >= 13) {
attr->axis = -1;
} else {
attr->axis = 1;
}
}

op->primitive->value.type = schema::PrimitiveType_SoftMax;
op->primitive->value.value = attr.release();
return RET_OK;


Loading…
Cancel
Save