diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index b5a6799db0..a4bdcc9eae 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -150,6 +150,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { auto tensor = metaGraphT->allTensors[input].get(); if (tensor->data.empty()) { tensor->nodeType = schema::NodeType_ValueNode; + tensor->format = schema::Format_NHWC; // tensor->refCount = lite::MSCONST_WEIGHT_REFCOUNT; metaGraphT->inputIndex.emplace_back(input); } diff --git a/mindspore/lite/src/ops/addn.cc b/mindspore/lite/src/ops/addn.cc index 6b9252cd39..c8a8e4e0a8 100644 --- a/mindspore/lite/src/ops/addn.cc +++ b/mindspore/lite/src/ops/addn.cc @@ -36,6 +36,7 @@ int AddN::InferShape(std::vector inputs_, std::vectorSetFormat(input->GetFormat()); output->set_shape(input->shape()); output->set_data_type(input->data_type()); return RET_OK; diff --git a/mindspore/lite/src/ops/argmax.cc b/mindspore/lite/src/ops/argmax.cc index af94e597e7..d71910e438 100644 --- a/mindspore/lite/src/ops/argmax.cc +++ b/mindspore/lite/src/ops/argmax.cc @@ -40,6 +40,7 @@ int ArgMax::InferShape(std::vector inputs_, std::vectorSetFormat(input->GetFormat()); output->set_shape(output_shape); output->set_data_type(input->data_type()); return RET_OK; diff --git a/mindspore/lite/src/ops/argmin.cc b/mindspore/lite/src/ops/argmin.cc index 2323af643f..b501b14b1c 100644 --- a/mindspore/lite/src/ops/argmin.cc +++ b/mindspore/lite/src/ops/argmin.cc @@ -39,9 +39,9 @@ int ArgMin::InferShape(std::vector inputs_, std::vector output_shape(input->shape()); output_shape.erase(output_shape.begin() + axis); + output->SetFormat(input->GetFormat()); output->set_shape(output_shape); output->set_data_type(input->data_type()); return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/arithmetic.cc b/mindspore/lite/src/ops/arithmetic.cc index 2a6ce1320e..2b15e22608 100644 --- a/mindspore/lite/src/ops/arithmetic.cc +++ b/mindspore/lite/src/ops/arithmetic.cc @@ -39,7 +39,7 @@ int Arithmetic::InferShape(std::vector inputs_, std::vectorshape(); auto input_shape1 = input1->shape(); - + auto format = input0->GetFormat(); in_shape0_.resize(5); in_shape1_.resize(5); out_shape_.resize(5); @@ -57,6 +57,7 @@ int Arithmetic::InferShape(std::vector inputs_, std::vectorGetFormat(); } else if (input_shape0.size() > input_shape1.size()) { ndim_ = input_shape0.size(); auto fill_dim_num = input_shape0.size() - input_shape1.size(); @@ -93,7 +94,7 @@ int Arithmetic::InferShape(std::vector inputs_, std::vectorSetFormat(format); output->set_shape(output_shape); output->set_data_type(input0->data_type()); return RET_OK; diff --git a/mindspore/lite/src/ops/arithmetic_self.cc b/mindspore/lite/src/ops/arithmetic_self.cc index 567a190f6a..3d2210e746 100644 --- a/mindspore/lite/src/ops/arithmetic_self.cc +++ b/mindspore/lite/src/ops/arithmetic_self.cc @@ -26,9 +26,11 @@ int ArithmeticSelf::InferShape(std::vector inputs_, std::vecto MS_ASSERT(input != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); + + output->SetFormat(input->GetFormat()); output->set_shape(input->shape()); output->set_data_type(input->data_type()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/batch_to_space.cc b/mindspore/lite/src/ops/batch_to_space.cc index a3ca0b2b49..41412c58c3 100644 --- a/mindspore/lite/src/ops/batch_to_space.cc +++ b/mindspore/lite/src/ops/batch_to_space.cc @@ -85,9 +85,10 @@ int BatchToSpace::InferShape(std::vector inputs, std::vectorGet(0) - crops->Get(0) - crops->Get(1); output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3); output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index]; + + outputs[0]->SetFormat(input->GetFormat()); outputs[0]->set_shape(output_shape); outputs[0]->set_data_type(input->data_type()); return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/broadcast_to.cc b/mindspore/lite/src/ops/broadcast_to.cc index 225e34d614..51e5914677 100644 --- a/mindspore/lite/src/ops/broadcast_to.cc +++ b/mindspore/lite/src/ops/broadcast_to.cc @@ -58,9 +58,9 @@ int BroadcastTo::InferShape(std::vector inputs, std::vectorSetFormat(input->GetFormat()); outputs[0]->set_shape(shape); outputs[0]->set_data_type(input->data_type()); return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc index 796f80cbee..13de84ff5e 100644 --- a/mindspore/lite/src/ops/cast.cc +++ b/mindspore/lite/src/ops/cast.cc @@ -44,9 +44,9 @@ int Cast::InferShape(std::vector inputs_, std::vectordstT(); return RET_INPUT_TENSOR_ERROR; } + output->SetFormat(input->GetFormat()); output->set_shape(input->shape()); output->set_data_type(input->data_type()); return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc index 2d966676d5..e69e2707a3 100644 --- a/mindspore/lite/src/ops/concat.cc +++ b/mindspore/lite/src/ops/concat.cc @@ -70,7 +70,8 @@ int Concat::InferShape(std::vector inputs_, std::vectorset_shape(output_shape); output->set_data_type(input0->data_type()); + output->SetFormat(input0->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/crop.cc b/mindspore/lite/src/ops/crop.cc index b58b8a27f4..dceab29f9b 100644 --- a/mindspore/lite/src/ops/crop.cc +++ b/mindspore/lite/src/ops/crop.cc @@ -32,7 +32,8 @@ int Crop::InferShape(std::vector inputs, std::vectorset_shape(inputs[1]->shape()); + outputs[0]->SetFormat(inputs[1]->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/depth_to_space.cc b/mindspore/lite/src/ops/depth_to_space.cc index 025c1ad360..f09fddfb58 100644 --- a/mindspore/lite/src/ops/depth_to_space.cc +++ b/mindspore/lite/src/ops/depth_to_space.cc @@ -23,7 +23,7 @@ namespace mindspore::lite { namespace { constexpr int kDepthToSpaceOutputNum = 1; constexpr int kDepthToSpaceInputNum = 1; -} +} // namespace int DepthToSpace::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive != nullptr); @@ -56,7 +56,8 @@ int DepthToSpace::InferShape(std::vector inputs, std::vectorset_shape(output_shape); outputs[0]->set_data_type(input->data_type()); + outputs[0]->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc index 588710f886..5b0391d654 100644 --- a/mindspore/lite/src/ops/expand_dims.cc +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -45,7 +45,8 @@ int ExpandDims::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/fill.cc b/mindspore/lite/src/ops/fill.cc index 361a5e2b8d..f4bd0c1952 100644 --- a/mindspore/lite/src/ops/fill.cc +++ b/mindspore/lite/src/ops/fill.cc @@ -42,7 +42,8 @@ int Fill::InferShape(std::vector inputs_, std::vectordims()->begin(), fill_prim->dims()->end()); output->set_shape(output_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/flatten.cc b/mindspore/lite/src/ops/flatten.cc index c2264afcf9..bde0cd16c5 100644 --- a/mindspore/lite/src/ops/flatten.cc +++ b/mindspore/lite/src/ops/flatten.cc @@ -43,7 +43,8 @@ int Flatten::InferShape(std::vector inputs_, std::vectorset_shape(output_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/fullconnection.cc b/mindspore/lite/src/ops/fullconnection.cc index 0b44faecfd..7b4b1e051f 100644 --- a/mindspore/lite/src/ops/fullconnection.cc +++ b/mindspore/lite/src/ops/fullconnection.cc @@ -56,7 +56,8 @@ int FullConnection::InferShape(std::vector inputs_, std::vecto out_shape[fc_prim->axis()] = input1->shape()[0]; output->set_shape(out_shape); output->set_data_type(input0->data_type()); + output->SetFormat(input0->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index 0e5cd61918..328de9ba2f 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -71,7 +71,8 @@ int Gather::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/gather_nd.cc b/mindspore/lite/src/ops/gather_nd.cc index 4f5598817b..681e2d207b 100644 --- a/mindspore/lite/src/ops/gather_nd.cc +++ b/mindspore/lite/src/ops/gather_nd.cc @@ -59,7 +59,8 @@ int GatherNd::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index d7cb772f41..2d031378bf 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -57,7 +57,8 @@ int MatMul::InferShape(std::vector inputs_, std::vectorset_shape(y_shape); output->set_data_type(input0->data_type()); + output->SetFormat(input0->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc index eb96edad89..878813c995 100644 --- a/mindspore/lite/src/ops/one_hot.cc +++ b/mindspore/lite/src/ops/one_hot.cc @@ -67,6 +67,8 @@ int OneHot::InferShape(std::vector inputs, std::vectorset_data_type(on_value->data_type()); + output->SetFormat(on_value->GetFormat()); + return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/ops.cc b/mindspore/lite/src/ops/ops.cc index 003c6e8db0..9719ddc061 100644 --- a/mindspore/lite/src/ops/ops.cc +++ b/mindspore/lite/src/ops/ops.cc @@ -140,7 +140,8 @@ int Primitive::InferShape(std::vector inputs_, std::vectorset_shape(input->shape()); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index 8604da24e3..3bdbe04235 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -55,9 +55,9 @@ int Pad::InferShape(std::vector inputs, std::vectorSetFormat(input->GetFormat()); output->set_shape(output_shape); output->set_data_type(input->data_type()); return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc index c25a558bbc..20745e7cd0 100644 --- a/mindspore/lite/src/ops/pooling.cc +++ b/mindspore/lite/src/ops/pooling.cc @@ -74,6 +74,7 @@ int Pooling::InferShape(std::vector inputs_, std::vectorset_shape(input_shape); output->set_data_type(input->data_type()); + // todo: temp fix output->SetFormat(schema::Format_NHWC); return RET_OK; diff --git a/mindspore/lite/src/ops/range.cc b/mindspore/lite/src/ops/range.cc index 4adafa2689..53180a8d51 100644 --- a/mindspore/lite/src/ops/range.cc +++ b/mindspore/lite/src/ops/range.cc @@ -34,7 +34,8 @@ int Range::InferShape(std::vector inputs_, std::vectorset_shape(in_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/rank.cc b/mindspore/lite/src/ops/rank.cc index c7b70930b1..5939396d16 100644 --- a/mindspore/lite/src/ops/rank.cc +++ b/mindspore/lite/src/ops/rank.cc @@ -29,7 +29,8 @@ int Rank::InferShape(std::vector inputs_, std::vector in_shape(1, 1); output->set_shape(in_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index 888a61df87..76ce819977 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -73,6 +73,8 @@ int Reduce::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc index 3e2a9c6eef..1358769bb3 100644 --- a/mindspore/lite/src/ops/reshape.cc +++ b/mindspore/lite/src/ops/reshape.cc @@ -114,7 +114,8 @@ int Reshape::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc index 7dd387c636..9b7edd2d9e 100644 --- a/mindspore/lite/src/ops/resize.cc +++ b/mindspore/lite/src/ops/resize.cc @@ -45,7 +45,8 @@ int Resize::InferShape(std::vector inputs_, std::vectorChannel()); output->set_shape(output_shape); output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/scatter_nd.cc b/mindspore/lite/src/ops/scatter_nd.cc index cf9f4dfbc3..446a37d872 100644 --- a/mindspore/lite/src/ops/scatter_nd.cc +++ b/mindspore/lite/src/ops/scatter_nd.cc @@ -57,7 +57,8 @@ int ScatterND::InferShape(std::vector inputs_, std::vector out_shape(shape_data, shape_data + sizeof(shape_data) / sizeof(shape_data[0])); output->set_shape(out_shape); output->set_data_type(update->data_type()); + output->SetFormat(update->GetFormat()); + return RET_OK; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc index 994fabb1cc..c1fe1d396d 100644 --- a/mindspore/lite/src/ops/slice.cc +++ b/mindspore/lite/src/ops/slice.cc @@ -23,7 +23,7 @@ namespace mindspore::lite { namespace { constexpr int kSliceInputNum = 1; constexpr int kSliceOutputNum = 1; -} +} // namespace int Slice::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive != nullptr); @@ -47,13 +47,13 @@ int Slice::InferShape(std::vector inputs, std::vector