| @@ -48,7 +48,7 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (prim.GetAttr("axis") == nullptr) { | if (prim.GetAttr("axis") == nullptr) { | ||||
| MS_LOG(WARNING) << "get axis failed"; | |||||
| MS_LOG(INFO) << "BiasAdd's attr axis is set to default"; | |||||
| attr->axis = {1}; | attr->axis = {1}; | ||||
| } else { | } else { | ||||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | ||||
| @@ -84,6 +84,8 @@ int Lstm::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output | |||||
| out_shape[2] = hidden_size; | out_shape[2] = hidden_size; | ||||
| if (GetBidirection()) { | if (GetBidirection()) { | ||||
| out_shape.insert(out_shape.begin() + 1, 2); | out_shape.insert(out_shape.begin() + 1, 2); | ||||
| } else { | |||||
| out_shape.insert(out_shape.begin() + 1, 1); | |||||
| } | } | ||||
| output->set_shape(out_shape); | output->set_shape(out_shape); | ||||
| // set hidden state, cell state | // set hidden state, cell state | ||||
| @@ -56,19 +56,19 @@ int Power::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu | |||||
| } | } | ||||
| if (prim.GetAttr("scale") == nullptr) { | if (prim.GetAttr("scale") == nullptr) { | ||||
| MS_LOG(WARNING) << "get scale failed"; | |||||
| MS_LOG(INFO) << "Power's attr scale is set to default"; | |||||
| attr->scale = 1.0f; | attr->scale = 1.0f; | ||||
| } else { | } else { | ||||
| attr->scale = GetValue<float>(prim.GetAttr("scale")); | attr->scale = GetValue<float>(prim.GetAttr("scale")); | ||||
| } | } | ||||
| if (prim.GetAttr("power") == nullptr) { | if (prim.GetAttr("power") == nullptr) { | ||||
| MS_LOG(WARNING) << "get power failed"; | |||||
| MS_LOG(INFO) << "Power's attr power is set to default"; | |||||
| attr->power = 1.0f; | attr->power = 1.0f; | ||||
| } else { | } else { | ||||
| attr->power = GetValue<float>(prim.GetAttr("power")); | attr->power = GetValue<float>(prim.GetAttr("power")); | ||||
| } | } | ||||
| if (prim.GetAttr("shift") == nullptr) { | if (prim.GetAttr("shift") == nullptr) { | ||||
| MS_LOG(WARNING) << "get shift failed"; | |||||
| MS_LOG(INFO) << "Power's attr shift is set to default"; | |||||
| attr->shift = 0; | attr->shift = 0; | ||||
| } else { | } else { | ||||
| attr->shift = GetValue<float>(prim.GetAttr("shift")); | attr->shift = GetValue<float>(prim.GetAttr("shift")); | ||||
| @@ -47,7 +47,7 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (prim.GetAttr("axis") == nullptr) { | if (prim.GetAttr("axis") == nullptr) { | ||||
| MS_LOG(WARNING) << "get axis failed"; | |||||
| MS_LOG(INFO) << "Squeeze's attr axis is set to default"; |||||
| attr->axis = {0}; | attr->axis = {0}; | ||||
| } else { | } else { | ||||
| int axis = GetValue<int>(prim.GetAttr("axis")); | int axis = GetValue<int>(prim.GetAttr("axis")); | ||||
| @@ -53,7 +53,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (prim.GetAttr("dims") == nullptr) { | if (prim.GetAttr("dims") == nullptr) { | ||||
| MS_LOG(WARNING) << "get dims failed"; | |||||
| MS_LOG(INFO) << "Tile's attr dims is set to default"; | |||||
| attr->dims = {1}; | attr->dims = {1}; | ||||
| } else { | } else { | ||||
| attr->dims = GetValue<std::vector<int>>(prim.GetAttr("dims")); | attr->dims = GetValue<std::vector<int>>(prim.GetAttr("dims")); | ||||
| @@ -124,6 +124,7 @@ int Transpose::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o | |||||
| if (!GetInferFlag()) { | if (!GetInferFlag()) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| MS_ASSERT(inputs_.size() == kSingleNum || inputs_.size() == kDoubleNum); | |||||
| MS_ASSERT(outputs_.size() == kSingleNum); | MS_ASSERT(outputs_.size() == kSingleNum); | ||||
| int conjugate = GetConjugate(); | int conjugate = GetConjugate(); | ||||
| @@ -116,6 +116,7 @@ int TransposeFp32Run(void *cdata, int task_id) { | |||||
| } | } | ||||
| int TransposeCPUKernel::Run() { | int TransposeCPUKernel::Run() { | ||||
| MS_ASSERT(in_tensors_.size() == 1 || in_tensors_.size() == 2); | |||||
| MS_ASSERT(out_tensors_.size() == 1); | MS_ASSERT(out_tensors_.size() == 1); | ||||
| auto &in_tensor = in_tensors_.front(); | auto &in_tensor = in_tensors_.front(); | ||||
| auto &out_tensor = out_tensors_.front(); | auto &out_tensor = out_tensors_.front(); | ||||
| @@ -15,7 +15,9 @@ | |||||
| */ | */ | ||||
| #include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h" | #include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h" | ||||
| #include "tools/converter/parser/onnx/onnx_tensor_parser.h" | |||||
| #include <memory> | #include <memory> | ||||
| #include <numeric> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -130,21 +132,21 @@ STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||||
| } | } | ||||
| const auto &onnx_pow_power = onnx_node.input(1); | const auto &onnx_pow_power = onnx_node.input(1); | ||||
| auto nodeIter = | |||||
| std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), | |||||
| [onnx_pow_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_pow_power; }); | |||||
| if (nodeIter == onnx_graph.node().end()) { | |||||
| int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_pow_power); | |||||
| if (index == -1) { | |||||
| MS_LOG(ERROR) << "can not find node: " << onnx_pow_power; | MS_LOG(ERROR) << "can not find node: " << onnx_pow_power; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| const float *pW = nullptr; | |||||
| for (const auto &attrPower : nodeIter->attribute()) { | |||||
| if (attrPower.name() == "value") { | |||||
| const auto &t = attrPower.t(); | |||||
| pW = reinterpret_cast<const float *>(t.raw_data().data()); | |||||
| } | |||||
| auto pow_attr = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index]; | |||||
| if (std::accumulate(pow_attr->dims.begin(), pow_attr->dims.end(), 1, std::multiplies<int>()) != 1) { | |||||
| MS_LOG(ERROR) << "the exponent element num is bigger than 1, which don't support now."; | |||||
| return RET_NOT_SUPPORT; | |||||
| } | } | ||||
| attr->power = *pW; | |||||
| if (pow_attr->data.data() == nullptr) { | |||||
| MS_LOG(ERROR) << "power's attr pow can't be obtained."; | |||||
| return RET_INVALID_OP_ATTR; | |||||
| } | |||||
| attr->power = *reinterpret_cast<float *>(pow_attr->data.data()); | |||||
| attr->scale = 1.0f; | attr->scale = 1.0f; | ||||
| attr->shift = 0.0f; | attr->shift = 0.0f; | ||||
| op->primitive->value.type = schema::PrimitiveType_Power; | op->primitive->value.type = schema::PrimitiveType_Power; | ||||
| @@ -27,7 +27,7 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC | |||||
| } | } | ||||
| std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>(); | std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>(); | ||||
| if (deDepthwiseConv2DParam == nullptr) { | if (deDepthwiseConv2DParam == nullptr) { | ||||
| MS_LOG(WARNING) << "new op failed"; | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return false; | return false; | ||||
| } | } | ||||
| deDepthwiseConv2DParam->format = attr->format; | deDepthwiseConv2DParam->format = attr->format; | ||||
| @@ -374,8 +374,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) { | |||||
| } else if (utils::isa<ValueNodePtr>(n)) { | } else if (utils::isa<ValueNodePtr>(n)) { | ||||
| value_node = utils::cast<ValueNodePtr>(n); | value_node = utils::cast<ValueNodePtr>(n); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "only value node or cnode has type"; | |||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR); | |||||
| MS_LOG(INFO) << "only value node or cnode has type"; | |||||
| return schema::PrimitiveType_NONE; | return schema::PrimitiveType_NONE; | ||||
| } | } | ||||
| if (value_node == nullptr) { | if (value_node == nullptr) { | ||||