Browse Source

!11521 fix packed-op implementation and Gather InferShape

From: @xutianchun
Reviewed-by: @hangangqiang,@HilbertDavid
Signed-off-by: @hangangqiang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
794ab0dfcf
6 changed files with 15 additions and 8 deletions
  1. +8
    -0
      mindspore/lite/src/common/graph_util.cc
  2. +2
    -0
      mindspore/lite/src/common/graph_util.h
  3. +1
    -1
      mindspore/lite/src/lite_session.cc
  4. +3
    -0
      mindspore/lite/src/ops/gather.cc
  5. +1
    -1
      mindspore/lite/src/scheduler.cc
  6. +0
    -6
      mindspore/lite/src/scheduler.h

+ 8
- 0
mindspore/lite/src/common/graph_util.cc View File

@@ -81,5 +81,13 @@ std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, const size_t
}
return post_node_idxes;
}

// Returns true when the given primitive type is a "packed" op, i.e. one whose
// weight tensors are repacked in place by the kernel at Init/Prepare time
// (so their original weight data must not be shared or copied naively).
bool IsPackedOp(schema::PrimitiveType op_type) {
  switch (op_type) {
    // Convolution-family ops and MatMul/Lstm repack their weights.
    case schema::PrimitiveType_Conv2D:
    case schema::PrimitiveType_DeConv2D:
    case schema::PrimitiveType_DepthwiseConv2D:
    case schema::PrimitiveType_DeDepthwiseConv2D:
    case schema::PrimitiveType_MatMul:
    case schema::PrimitiveType_Lstm:
      return true;
    default:
      return false;
  }
}
} // namespace lite
} // namespace mindspore

+ 2
- 0
mindspore/lite/src/common/graph_util.h View File

@@ -36,6 +36,8 @@ std::vector<size_t> GetGraphInputNodes(const lite::Model *model);
std::vector<size_t> GetGraphOutputNodes(const lite::Model *model);

std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, size_t tensor_idx);

bool IsPackedOp(schema::PrimitiveType op_type);
} // namespace lite
} // namespace mindspore



+ 1
- 1
mindspore/lite/src/lite_session.cc View File

@@ -49,7 +49,7 @@ static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor
return std::none_of(post_node_idxes.begin(), post_node_idxes.end(), [&](const size_t &post_node_idx) {
auto node = model->all_nodes_[post_node_idx];
MS_ASSERT(node != nullptr);
return IsContain(packed_op, static_cast<schema::PrimitiveType>(node->primitive_->Type()));
return IsPackedOp(static_cast<schema::PrimitiveType>(node->primitive_->Type()));
});
}



+ 3
- 0
mindspore/lite/src/ops/gather.cc View File

@@ -112,6 +112,9 @@ int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
auto output = outputs_.front();
MS_ASSERT(input != nullptr);
output->set_data_type(input->data_type());
if (this->quant_type() == schema::QuantType_WeightQuant) {
output->set_data_type(kNumberTypeFloat32);
}
output->set_format(input->format());
if (!infer_flag()) {
return RET_INFER_INVALID;


+ 1
- 1
mindspore/lite/src/scheduler.cc View File

@@ -188,7 +188,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
if (primitive->quant_type() == schema::QuantType_WeightQuant) {
data_type = kNumberTypeFloat32;
}
if (!IsContain(packed_op, (schema::PrimitiveType)primitive->Type())) {
if (!IsPackedOp((schema::PrimitiveType)primitive->Type())) {
need_restore = false;
}
kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};


+ 0
- 6
mindspore/lite/src/scheduler.h View File

@@ -26,12 +26,6 @@
#include "src/ops/primitive_c.h"

namespace mindspore::lite {

static std::vector<schema::PrimitiveType> packed_op = {
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_MatMul, schema::PrimitiveType_Lstm};

class Scheduler {
public:
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors)


Loading…
Cancel
Save