| @@ -660,7 +660,7 @@ table NetOutput { | |||||
| } | } | ||||
| table MatMul { | table MatMul { | ||||
| broadcast : bool = false; | |||||
| broadcast : bool = false; // DEPRECATED | |||||
| transposeA : bool = false; | transposeA : bool = false; | ||||
| transposeB : bool = false; | transposeB : bool = false; | ||||
| } | } | ||||
| @@ -189,7 +189,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT | |||||
| attr->channelMultiplier = channel_mutiplier; | attr->channelMultiplier = channel_mutiplier; | ||||
| MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo); | MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo); | ||||
| auto input_node = inputs[kAnfPopulaterInputNumOne]; | |||||
| auto input_node = inputs.at(kAnfPopulaterInputNumOne); | |||||
| MS_ASSERT(input_node != nullptr); | MS_ASSERT(input_node != nullptr); | ||||
| if (input_node->isa<Parameter>()) { | if (input_node->isa<Parameter>()) { | ||||
| auto param_node = input_node->cast<ParameterPtr>(); | auto param_node = input_node->cast<ParameterPtr>(); | ||||
| @@ -201,7 +201,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT | |||||
| MS_ASSERT(abstractTensor != nullptr); | MS_ASSERT(abstractTensor != nullptr); | ||||
| if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | ||||
| auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | ||||
| attr->channelIn = dims[kAnfPopulaterInputNumOne]; | |||||
| attr->channelIn = dims.at(kAnfPopulaterInputNumOne); | |||||
| } | } | ||||
| } | } | ||||
| } else if (input_node->isa<CNode>()) { | } else if (input_node->isa<CNode>()) { | ||||
| @@ -128,7 +128,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||||
| attr->channelMultiplier = channel_multiplier; | attr->channelMultiplier = channel_multiplier; | ||||
| MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo); | MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo); | ||||
| auto inputNode = inputs[kAnfPopulaterInputNumOne]; | |||||
| auto inputNode = inputs.at(kAnfPopulaterInputNumOne); | |||||
| MS_ASSERT(inputNode != nullptr); | MS_ASSERT(inputNode != nullptr); | ||||
| if (inputNode->isa<Parameter>()) { | if (inputNode->isa<Parameter>()) { | ||||
| auto paramNode = inputNode->cast<ParameterPtr>(); | auto paramNode = inputNode->cast<ParameterPtr>(); | ||||
| @@ -139,7 +139,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||||
| MS_ASSERT(abstractTensor != nullptr); | MS_ASSERT(abstractTensor != nullptr); | ||||
| if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | ||||
| auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | ||||
| attr->channelIn = dims[kAnfPopulaterInputNumOne]; | |||||
| attr->channelIn = dims.at(kAnfPopulaterInputNumOne); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -42,9 +42,6 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| float beta = 1.0f; | float beta = 1.0f; | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| const auto &attribute_name = onnx_node_attr.name(); | const auto &attribute_name = onnx_node_attr.name(); | ||||
| if (attribute_name == "broadcast") { | |||||
| attr->broadcast = static_cast<bool>(onnx_node_attr.i()); | |||||
| } | |||||
| if (attribute_name == "transA") { | if (attribute_name == "transA") { | ||||
| attr->transposeA = static_cast<bool>(onnx_node_attr.i()); | attr->transposeA = static_cast<bool>(onnx_node_attr.i()); | ||||
| } else if (attribute_name == "transB") { | } else if (attribute_name == "transB") { | ||||
| @@ -199,7 +199,6 @@ STATUS TfliteCustomParser::BatchMatMul(const std::vector<uint8_t> &custom_attr, | |||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->broadcast = false; | |||||
| attr->transposeA = false; | attr->transposeA = false; | ||||
| attr->transposeB = false; | attr->transposeB = false; | ||||
| op->primitive->value.type = schema::PrimitiveType_MatMul; | op->primitive->value.type = schema::PrimitiveType_MatMul; | ||||
| @@ -36,7 +36,6 @@ PrimitiveC *TfliteMatMulParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsBatchMatMulOptions(); | const auto &tflite_attr = tflite_op->builtin_options.AsBatchMatMulOptions(); | ||||
| attr->transposeA = tflite_attr->adj_x; | attr->transposeA = tflite_attr->adj_x; | ||||
| attr->transposeB = tflite_attr->adj_y; | attr->transposeB = tflite_attr->adj_y; | ||||
| attr->broadcast = false; | |||||
| primitive->value.type = schema::PrimitiveType_MatMul; | primitive->value.type = schema::PrimitiveType_MatMul; | ||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||