Browse Source

fix code review

tags/v1.1.0
sunsuodong 5 years ago
parent
commit
3d7e9b0c79
6 changed files with 5 additions and 10 deletions
  1. +1
    -1
      mindspore/lite/schema/ops.fbs
  2. +2
    -2
      mindspore/lite/src/ops/conv2d.cc
  3. +2
    -2
      mindspore/lite/src/ops/depthwise_conv2d.cc
  4. +0
    -3
      mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc
  5. +0
    -1
      mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc
  6. +0
    -1
      mindspore/lite/tools/converter/parser/tflite/tflite_matmul_parser.cc

+ 1
- 1
mindspore/lite/schema/ops.fbs View File

@@ -660,7 +660,7 @@ table NetOutput {
}

table MatMul {
broadcast : bool = false;
broadcast : bool = false; // DEPRECATED
transposeA : bool = false;
transposeB : bool = false;
}


+ 2
- 2
mindspore/lite/src/ops/conv2d.cc View File

@@ -189,7 +189,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
attr->channelMultiplier = channel_mutiplier;

MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo);
auto input_node = inputs[kAnfPopulaterInputNumOne];
auto input_node = inputs.at(kAnfPopulaterInputNumOne);
MS_ASSERT(input_node != nullptr);
if (input_node->isa<Parameter>()) {
auto param_node = input_node->cast<ParameterPtr>();
@@ -201,7 +201,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
MS_ASSERT(abstractTensor != nullptr);
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
attr->channelIn = dims[kAnfPopulaterInputNumOne];
attr->channelIn = dims.at(kAnfPopulaterInputNumOne);
}
}
} else if (input_node->isa<CNode>()) {


+ 2
- 2
mindspore/lite/src/ops/depthwise_conv2d.cc View File

@@ -128,7 +128,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
attr->channelMultiplier = channel_multiplier;

MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo);
auto inputNode = inputs[kAnfPopulaterInputNumOne];
auto inputNode = inputs.at(kAnfPopulaterInputNumOne);
MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<Parameter>()) {
auto paramNode = inputNode->cast<ParameterPtr>();
@@ -139,7 +139,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
MS_ASSERT(abstractTensor != nullptr);
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
attr->channelIn = dims[kAnfPopulaterInputNumOne];
attr->channelIn = dims.at(kAnfPopulaterInputNumOne);
}
}
}


+ 0
- 3
mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc View File

@@ -42,9 +42,6 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
float beta = 1.0f;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "broadcast") {
attr->broadcast = static_cast<bool>(onnx_node_attr.i());
}
if (attribute_name == "transA") {
attr->transposeA = static_cast<bool>(onnx_node_attr.i());
} else if (attribute_name == "transB") {


+ 0
- 1
mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc View File

@@ -199,7 +199,6 @@ STATUS TfliteCustomParser::BatchMatMul(const std::vector<uint8_t> &custom_attr,
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->broadcast = false;
attr->transposeA = false;
attr->transposeB = false;
op->primitive->value.type = schema::PrimitiveType_MatMul;


+ 0
- 1
mindspore/lite/tools/converter/parser/tflite/tflite_matmul_parser.cc View File

@@ -36,7 +36,6 @@ PrimitiveC *TfliteMatMulParser::ParseLitePrimitive(const std::unique_ptr<tflite:
const auto &tflite_attr = tflite_op->builtin_options.AsBatchMatMulOptions();
attr->transposeA = tflite_attr->adj_x;
attr->transposeB = tflite_attr->adj_y;
attr->broadcast = false;
primitive->value.type = schema::PrimitiveType_MatMul;
primitive->value.value = attr.release();



Loading…
Cancel
Save