Browse Source

adjust log and fix LSTM InferShape

tags/v1.1.0
xuanyue 5 years ago
parent
commit
a3f5f60af0
10 changed files with 25 additions and 20 deletions
  1. +1
    -1
      mindspore/lite/src/ops/bias_add.cc
  2. +2
    -0
      mindspore/lite/src/ops/lstm.cc
  3. +3
    -3
      mindspore/lite/src/ops/power.cc
  4. +1
    -1
      mindspore/lite/src/ops/squeeze.cc
  5. +1
    -1
      mindspore/lite/src/ops/tile.cc
  6. +1
    -0
      mindspore/lite/src/ops/transpose.cc
  7. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc
  8. +13
    -11
      mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc
  9. +1
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc
  10. +1
    -2
      mindspore/lite/tools/optimizer/common/gllo_utils.cc

+ 1
- 1
mindspore/lite/src/ops/bias_add.cc View File

@@ -48,7 +48,7 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
return RET_ERROR;
}
if (prim.GetAttr("axis") == nullptr) {
MS_LOG(WARNING) << "get axis failed";
MS_LOG(INFO) << "BiasAdd's attr axis is set to default";
attr->axis = {1};
} else {
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));


+ 2
- 0
mindspore/lite/src/ops/lstm.cc View File

@@ -84,6 +84,8 @@ int Lstm::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
out_shape[2] = hidden_size;
if (GetBidirection()) {
out_shape.insert(out_shape.begin() + 1, 2);
} else {
out_shape.insert(out_shape.begin() + 1, 1);
}
output->set_shape(out_shape);
// set hidden state, cell state


+ 3
- 3
mindspore/lite/src/ops/power.cc View File

@@ -56,19 +56,19 @@ int Power::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
}

if (prim.GetAttr("scale") == nullptr) {
MS_LOG(WARNING) << "get scale failed";
MS_LOG(INFO) << "Power's attr scale is set to default";
attr->scale = 1.0f;
} else {
attr->scale = GetValue<float>(prim.GetAttr("scale"));
}
if (prim.GetAttr("power") == nullptr) {
MS_LOG(WARNING) << "get power failed";
MS_LOG(INFO) << "Power's attr power is set to default";
attr->power = 1.0f;
} else {
attr->power = GetValue<float>(prim.GetAttr("power"));
}
if (prim.GetAttr("shift") == nullptr) {
MS_LOG(WARNING) << "get shift failed";
MS_LOG(INFO) << "Power's attr shift is set to default";
attr->shift = 0;
} else {
attr->shift = GetValue<float>(prim.GetAttr("shift"));


+ 1
- 1
mindspore/lite/src/ops/squeeze.cc View File

@@ -47,7 +47,7 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
return RET_ERROR;
}
if (prim.GetAttr("axis") == nullptr) {
MS_LOG(WARNING) << "get axis failed";
MS_LOG(INFO) << "Squeeze's attr axis is set to default";
attr->axis = {0};
} else {
int axis = GetValue<int>(prim.GetAttr("axis"));


+ 1
- 1
mindspore/lite/src/ops/tile.cc View File

@@ -53,7 +53,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
return RET_ERROR;
}
if (prim.GetAttr("dims") == nullptr) {
MS_LOG(WARNING) << "get dims failed";
MS_LOG(INFO) << "Tile's attr dims is set to default";
attr->dims = {1};
} else {
attr->dims = GetValue<std::vector<int>>(prim.GetAttr("dims"));


+ 1
- 0
mindspore/lite/src/ops/transpose.cc View File

@@ -124,6 +124,7 @@ int Transpose::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o
if (!GetInferFlag()) {
return RET_OK;
}
MS_ASSERT(inputs_.size() == kSingleNum || inputs_.size() == kDoubleNum);
MS_ASSERT(outputs_.size() == kSingleNum);

int conjugate = GetConjugate();


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc View File

@@ -116,6 +116,7 @@ int TransposeFp32Run(void *cdata, int task_id) {
}

int TransposeCPUKernel::Run() {
MS_ASSERT(in_tensors_.size() == 1 || in_tensors_.size() == 2);
MS_ASSERT(out_tensors_.size() == 1);
auto &in_tensor = in_tensors_.front();
auto &out_tensor = out_tensors_.front();


+ 13
- 11
mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc View File

@@ -15,7 +15,9 @@
*/

#include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h"
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"
#include <memory>
#include <numeric>

namespace mindspore {
namespace lite {
@@ -130,21 +132,21 @@ STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
}

const auto &onnx_pow_power = onnx_node.input(1);
auto nodeIter =
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
[onnx_pow_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_pow_power; });
if (nodeIter == onnx_graph.node().end()) {
int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_pow_power);
if (index == -1) {
MS_LOG(ERROR) << "can not find node: " << onnx_pow_power;
return RET_ERROR;
}
const float *pW = nullptr;
for (const auto &attrPower : nodeIter->attribute()) {
if (attrPower.name() == "value") {
const auto &t = attrPower.t();
pW = reinterpret_cast<const float *>(t.raw_data().data());
}
auto pow_attr = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index];
if (std::accumulate(pow_attr->dims.begin(), pow_attr->dims.end(), 1, std::multiplies<int>()) != 1) {
MS_LOG(ERROR) << "the exponent element num is bigger than 1, which isn't supported now.";
return RET_NOT_SUPPORT;
}
attr->power = *pW;
if (pow_attr->data.data() == nullptr) {
MS_LOG(ERROR) << "power's attr pow can't be obtained.";
return RET_INVALID_OP_ATTR;
}
attr->power = *reinterpret_cast<float *>(pow_attr->data.data());
attr->scale = 1.0f;
attr->shift = 0.0f;
op->primitive->value.type = schema::PrimitiveType_Power;


+ 1
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc View File

@@ -27,7 +27,7 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
}
std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>();
if (deDepthwiseConv2DParam == nullptr) {
MS_LOG(WARNING) << "new op failed";
MS_LOG(ERROR) << "new op failed";
return false;
}
deDepthwiseConv2DParam->format = attr->format;


+ 1
- 2
mindspore/lite/tools/optimizer/common/gllo_utils.cc View File

@@ -374,8 +374,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
} else if (utils::isa<ValueNodePtr>(n)) {
value_node = utils::cast<ValueNodePtr>(n);
} else {
MS_LOG(ERROR) << "only value node or cnode has type";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
MS_LOG(INFO) << "only value node or cnode has type";
return schema::PrimitiveType_NONE;
}
if (value_node == nullptr) {


Loading…
Cancel
Save