Browse Source

add NHWC format requirement of topk ops

tags/v1.1.0
liuwenhao4 5 years ago
parent
commit
1957afa243
2 changed files with 6 additions and 1 deletions
  1. +4
    -0
      mindspore/lite/src/ops/topk.cc
  2. +2
    -1
      mindspore/lite/tools/common/node_util.cc

+ 4
- 0
mindspore/lite/src/ops/topk.cc View File

@@ -60,6 +60,10 @@ int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
}
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
if (input->format() != schema::Format::Format_NHWC) {
MS_LOG(ERROR) << "topk only support NHWC now!";
return RET_FORMAT_ERR;
}
auto output0 = outputs_.front();
MS_ASSERT(output0 != nullptr);
auto output1 = outputs_.at(1);


+ 2
- 1
mindspore/lite/tools/common/node_util.cc View File

@@ -50,7 +50,8 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
schema::PrimitiveType_BiasAdd,
schema::PrimitiveType_InstanceNorm,
schema::PrimitiveType_SpaceToDepth,
schema::PrimitiveType_DepthToSpace};
schema::PrimitiveType_DepthToSpace,
schema::PrimitiveType_TopK};

static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
#ifdef SUPPORT_TRAIN


Loading…
Cancel
Save