From 205d101b9f32542d6731d9c141a60f5c3d294d55 Mon Sep 17 00:00:00 2001 From: xutianchun Date: Fri, 22 Jan 2021 11:06:08 +0800 Subject: [PATCH] fix PackOp implemant && gather infershape --- mindspore/lite/src/common/graph_util.cc | 8 ++++++++ mindspore/lite/src/common/graph_util.h | 2 ++ mindspore/lite/src/lite_session.cc | 2 +- mindspore/lite/src/ops/gather.cc | 3 +++ mindspore/lite/src/scheduler.cc | 2 +- mindspore/lite/src/scheduler.h | 6 ------ 6 files changed, 15 insertions(+), 8 deletions(-) diff --git a/mindspore/lite/src/common/graph_util.cc b/mindspore/lite/src/common/graph_util.cc index 11c1db155f..d43fb7405a 100644 --- a/mindspore/lite/src/common/graph_util.cc +++ b/mindspore/lite/src/common/graph_util.cc @@ -81,5 +81,13 @@ std::vector GetLinkedPostNodeIdx(const lite::Model *model, const size_t } return post_node_idxes; } + +bool IsPackedOp(schema::PrimitiveType op_type) { + static std::vector packed_ops = { + schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, + schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, + schema::PrimitiveType_MatMul, schema::PrimitiveType_Lstm}; + return IsContain(packed_ops, op_type); +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/common/graph_util.h b/mindspore/lite/src/common/graph_util.h index b6c566771d..a158d751fa 100644 --- a/mindspore/lite/src/common/graph_util.h +++ b/mindspore/lite/src/common/graph_util.h @@ -36,6 +36,8 @@ std::vector GetGraphInputNodes(const lite::Model *model); std::vector GetGraphOutputNodes(const lite::Model *model); std::vector GetLinkedPostNodeIdx(const lite::Model *model, size_t tensor_idx); + +bool IsPackedOp(schema::PrimitiveType op_type); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 48e415f109..6348521714 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -49,7 +49,7 @@ static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor return std::none_of(post_node_idxes.begin(), post_node_idxes.end(), [&](const size_t &post_node_idx) { auto node = model->all_nodes_[post_node_idx]; MS_ASSERT(node != nullptr); - return IsContain(packed_op, static_cast(node->primitive_->Type())); + return IsPackedOp(static_cast(node->primitive_->Type())); }); } diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index f4a1da13ee..9a7c6a7645 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -112,6 +112,9 @@ int Gather::InferShape(std::vector inputs_, std::vector outp auto output = outputs_.front(); MS_ASSERT(input != nullptr); output->set_data_type(input->data_type()); + if (this->quant_type() == schema::QuantType_WeightQuant) { + output->set_data_type(kNumberTypeFloat32); + } output->set_format(input->format()); if (!infer_flag()) { return RET_INFER_INVALID; diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index d3dbbbe9db..ff1d1a1623 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -188,7 +188,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector &in if (primitive->quant_type() == schema::QuantType_WeightQuant) { data_type = kNumberTypeFloat32; } - if (!IsContain(packed_op, (schema::PrimitiveType)primitive->Type())) { + if (!IsPackedOp((schema::PrimitiveType)primitive->Type())) { need_restore = false; } kernel::KernelKey desc{kCPU, data_type, static_cast(primitive->Type())}; diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h index 4866502838..bd5c9fac17 100644 --- a/mindspore/lite/src/scheduler.h +++ b/mindspore/lite/src/scheduler.h @@ -26,12 +26,6 @@ #include "src/ops/primitive_c.h" namespace mindspore::lite { - -static std::vector packed_op = { - schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, - schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, - schema::PrimitiveType_MatMul, schema::PrimitiveType_Lstm}; - class Scheduler { public: Scheduler(const InnerContext *ctx, Model *src_model, std::vector *src_tensors)