@@ -406,6 +406,30 @@ int Benchmark::RunBenchmark() {
     std::cout << "CompileGraph failed while running ", model_name.c_str();
     return ret;
   }
+  if (!flags_->input_shape_list_.empty()) {
+    std::vector<std::vector<int>> input_shapes;
+    std::string input_dims_list = flags_->input_shape_list_;
+    while (!input_dims_list.empty()) {
+      auto position =
+        input_dims_list.find(";") != input_dims_list.npos ? input_dims_list.find(";") + 1 : input_dims_list.length();
+      std::string input_dims = input_dims_list.substr(0, position);
+      std::vector<int> input_shape;
+      while (!input_dims.empty()) {
+        auto pos = input_dims.find(",") != input_dims.npos ? input_dims.find(",") + 1 : input_dims.length();
+        std::string dim = input_dims.substr(0, pos);
+        input_shape.emplace_back(std::stoi(dim));
+        input_dims = input_dims.substr(pos);
+      }
+      input_shapes.emplace_back(input_shape);
+      input_dims_list = input_dims_list.substr(position);
+    }
+    ret = session_->Resize(session_->GetInputs(), input_shapes);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Input tensor resize failed.";
+      std::cout << "Input tensor resize failed.";
+      return ret;
+    }
+  }
   model->Free();
   ms_inputs_ = session_->GetInputs();
   auto end_prepare_time = GetTimeUs();
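
The block added above parses the new input_shape_list_ string (registered as the inputShapes flag in the next hunk): it is split on ';' into per-tensor shapes and on ',' into dimensions, and session_->Resize() is then called on the session inputs before the timed runs. A stand-alone sketch of the same parsing, runnable in isolation (the ParseInputShapes helper and the main() driver are illustrative, not part of the patch):

#include <iostream>
#include <string>
#include <vector>

// Stand-alone mirror of the parsing loop added in Benchmark::RunBenchmark():
// "1,32,32,32;1,1,32,32" -> {{1, 32, 32, 32}, {1, 1, 32, 32}}
std::vector<std::vector<int>> ParseInputShapes(std::string input_dims_list) {
  std::vector<std::vector<int>> input_shapes;
  while (!input_dims_list.empty()) {
    // Take everything up to and including the next ';' as one shape string.
    auto position = input_dims_list.find(";") != std::string::npos ? input_dims_list.find(";") + 1
                                                                   : input_dims_list.length();
    std::string input_dims = input_dims_list.substr(0, position);
    std::vector<int> input_shape;
    while (!input_dims.empty()) {
      // Consume one comma-separated dimension; std::stoi ignores a trailing ',' or ';'.
      auto pos = input_dims.find(",") != std::string::npos ? input_dims.find(",") + 1 : input_dims.length();
      input_shape.emplace_back(std::stoi(input_dims.substr(0, pos)));
      input_dims = input_dims.substr(pos);
    }
    input_shapes.emplace_back(input_shape);
    input_dims_list = input_dims_list.substr(position);
  }
  return input_shapes;
}

int main() {
  for (const auto &shape : ParseInputShapes("1,32,32,32;1,1,32,32")) {
    for (int dim : shape) std::cout << dim << ' ';
    std::cout << std::endl;
  }
  return 0;
}

Because the whole block is guarded by flags_->input_shape_list_.empty(), invocations that do not pass inputShapes keep the shapes baked into the model.
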
@@ -70,6 +70,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
     AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType",
             "Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT");
     AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
+    AddFlag(&BenchmarkFlags::input_shape_list_, "inputShapes",
| "Shape of input data, the format should be NHWC. e.g. 1,32,32,32;1,1,32,32,1", ""); | |||||
   }
   ~BenchmarkFlags() override = default;
@@ -86,6 +88,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
   InDataType in_data_type_;
   std::string in_data_type_in_ = "bin";
   int cpu_bind_mode_ = 1;
+  std::string input_shape_list_;
   // MarkPerformance
   int loop_count_;
   int num_threads_;
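
Together, the two hunks above register the inputShapes option and back it with the input_shape_list_ member that RunBenchmark() parses. An assumed invocation, shown only to illustrate the flag's format (the binary name and the model-path flag are not part of this patch; quoting keeps the ';' separator away from the shell):

./benchmark --modelPath=model.ms --inputShapes="1,32,32,32;1,1,32,32"
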
@@ -26,6 +26,9 @@ using mindspore::lite::Tensor;
 namespace mindspore {
 namespace lite {
 namespace {
+constexpr int DEFAULT_DIM_VALUE = -1;
+}
+namespace {
 std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs,
                                                 const schema::PrimitiveType node_type) {
   std::vector<Tensor *> lite_tensors;
@@ -85,6 +88,15 @@ void FreeTensors(std::vector<Tensor *> input_tensors, std::vector<Tensor *> outp
 }  // namespace
 STATUS InferShapePass::Run(MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
+  for (auto idx : graph->inputIndex) {
+    auto input_tensor = graph->allTensors[idx].get();
+    for (auto &dim : input_tensor->dims) {
+      if (dim == 0) {
+        MS_LOG(WARNING) << "One dimension of the input shape is 0, which will be set to -1 by default.";
+        dim = DEFAULT_DIM_VALUE;
+      }
+    }
+  }
   for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
     auto &node = *iter;
     auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex, node->primitive->value.type);
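
The loop added to InferShapePass::Run() normalizes placeholder dimensions before shape inference: any 0 in a graph input's dims (commonly an exporter's stand-in for a dynamic axis) becomes DEFAULT_DIM_VALUE, i.e. -1, so downstream inference treats the axis as unknown rather than as an empty tensor. A minimal sketch of the same normalization on a plain dims vector (kDefaultDimValue mirrors the constant above; the function name is illustrative):

#include <vector>

constexpr int kDefaultDimValue = -1;  // same role as DEFAULT_DIM_VALUE in the pass

// {0, 32, 32, 3} -> {-1, 32, 32, 3}: mark zero-sized placeholder axes as unknown.
void NormalizeInputDims(std::vector<int> *dims) {
  for (auto &dim : *dims) {
    if (dim == 0) {
      dim = kDefaultDimValue;
    }
  }
}
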
@@ -41,7 +41,14 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     const auto &attribute_name = onnx_node_attr.name();
     if (attribute_name == "value") {
-      attr->value = static_cast<int32_t>(onnx_node_attr.i());
+      if (onnx_node_attr.type() == onnx::AttributeProto_AttributeType_TENSOR) {
+        auto tensor = onnx_node_attr.t();
+        if (tensor.data_type() == onnx::AttributeProto_AttributeType_FLOAT) {
+          attr->value = onnx_node_attr.f();
+        } else if (tensor.data_type() == onnx::AttributeProto_AttributeType_INT) {
+          attr->value = static_cast<int32_t>(onnx_node_attr.i());
+        }
+      }
     }
   }
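
For reference, ONNX defines ConstantOfShape's value attribute as a one-element TENSOR, and a TensorProto reports its element type via the TensorProto_DataType enum, which is a separate enum from the AttributeProto_AttributeType constants used in the comparison above. A hedged, stand-alone sketch of pulling the scalar out of that one-element tensor, shown only to illustrate the ONNX layout and not as the patch's code path (the include path for the generated protobufs is assumed):

#include <cstdint>
#include <cstring>

#include "onnx.pb.h"  // assumed location of the generated ONNX protobuf headers

// Read the single scalar held by ConstantOfShape's one-element "value" tensor.
// Exporters may place it either in the typed *_data fields or in raw_data.
float GetConstantOfShapeValue(const onnx::TensorProto &tensor) {
  if (tensor.data_type() == onnx::TensorProto_DataType_FLOAT) {
    if (tensor.float_data_size() > 0) {
      return tensor.float_data(0);
    }
    float value = 0.0f;
    std::memcpy(&value, tensor.raw_data().data(), sizeof(value));
    return value;
  }
  if (tensor.data_type() == onnx::TensorProto_DataType_INT64) {
    if (tensor.int64_data_size() > 0) {
      return static_cast<float>(tensor.int64_data(0));
    }
    int64_t value = 0;
    std::memcpy(&value, tensor.raw_data().data(), sizeof(value));
    return static_cast<float>(value);
  }
  return 0.0f;  // ONNX default: a float32 zero when no value attribute is given
}
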
@@ -66,14 +66,14 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
     const auto &attribute_name = onnx_node_attr.name();
     if (attribute_name == "kernel_shape") {
       if (onnx_node_attr.ints_size() == 2) {
-        attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(0));
-        attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(1));
+        attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(0));
+        attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(1));
       }
     }
     if (attribute_name == "strides") {
       if (onnx_node_attr.ints_size() == 2) {
-        attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
-        attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1));
+        attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
+        attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
       }
     }
     if (attribute_name == "auto_pad") {
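
The swaps above follow the ONNX convention that kernel_shape and strides for 2-D pooling are listed in spatial order (H, W): for kernel_shape = [3, 5], ints(0) is the window height 3 and ints(1) is the window width 5, so the previous mapping of ints(0) to windowW transposed the pooling window for non-square kernels. The same H-before-W order applies to the strides attribute.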