From: @zhengjun10
Reviewed-by:
Signed-off-by:
tags/v1.2.0-rc1
@@ -76,6 +76,12 @@ inline void Float32ToInt16(const float *input, int16_t *output, int number) {
    output[i] = (int16_t)input[i];
  }
}
inline void BoolToInt32(const bool *input, int32_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (int32_t)input[i];
  }
}
#ifdef __cplusplus
}
#endif
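
For reference, a minimal standalone sketch (not part of the patch) of the conversion the new routine performs; the main() harness is hypothetical:

#include <cstdint>
#include <cstdio>

// Same widening rule as the patch's BoolToInt32: false -> 0, true -> 1.
inline void BoolToInt32(const bool *input, int32_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (int32_t)input[i];
  }
}

int main() {
  bool in[4] = {true, false, true, true};
  int32_t out[4] = {0};
  BoolToInt32(in, out, 4);
  for (int i = 0; i < 4; ++i) printf("%d ", out[i]);  // 1 0 1 1
  printf("\n");
  return 0;
}
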
@@ -70,7 +70,7 @@ template <typename T>
void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape) {
  int input_count = inputs[0]->ElementsNum();
  int input_dim_size = inputs[0]->shape().empty() ? 1 : inputs[0]->shape().size();
  (*out_shape)[0] = input_dim_size;
  out_shape->emplace_back(input_dim_size);
  int nonzero_size = 0;
  for (int i = 0; i < input_count; i++) {
    if (static_cast<int>(data[i]) != 0) {
@@ -78,35 +78,34 @@ void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape) {
    }
  }
  if (nonzero_size == 0) {
    *out_shape = {};
    return;
  } else {
    (*out_shape)[1] = nonzero_size / input_dim_size;
    out_shape->emplace_back(nonzero_size);
  }
}
int NonZero::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  MS_ASSERT(this->primitive_ != nullptr);
  MS_ASSERT(inputs_.size() == 1);
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto input_tensor = inputs_.front();
  MS_ASSERT(input_tensor != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(input->data_type());
  output->set_format(input->format());
  output->set_data_type(TypeId::kNumberTypeInt32);
  output->set_format(input_tensor->format());
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  std::vector<int> out_shape;
  if (inputs_.size() == kSingleNum) {
    auto input_tensor = inputs_.at(0);
    if (input_tensor->data_c() == nullptr) {
      MS_LOG(INFO) << "Do infer shape in runtime.";
      return RET_INFER_INVALID;
    }
    switch (input_tensor->data_type()) {
      case kNumberTypeFloat: {
        auto data = reinterpret_cast<float *>(input_tensor->MutableData());
        CalShape<float>(data, inputs_, &out_shape);
      case kNumberTypeBool: {
        auto data = reinterpret_cast<bool *>(input_tensor->MutableData());
        CalShape<bool>(data, inputs_, &out_shape);
      } break;
      default: {
        MS_LOG(ERROR) << "NonZero weight tensor has unsupported dataType: " << input_tensor->data_type();
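
In other words, CalShape above now builds the NonZero output shape as [input_rank, nonzero_count] by appending, where the old code divided by the rank and wrote into pre-sized slots. A hedged standalone sketch of the new rule (NonZeroOutShape is a made-up name for illustration):

#include <cstdio>
#include <vector>

// Output shape rule after this patch: a rank-d input with n nonzero elements
// yields a [d, n] coordinate tensor; an all-zero input yields an empty shape.
std::vector<int> NonZeroOutShape(const std::vector<bool> &data, int input_rank) {
  int nonzero = 0;
  for (bool v : data) {
    if (v) ++nonzero;
  }
  if (nonzero == 0) return {};
  return {input_rank, nonzero};
}

int main() {
  std::vector<bool> data = {false, true, true, false, false, true};  // a 2x3 input
  auto shape = NonZeroOutShape(data, 2);
  printf("[%d, %d]\n", shape[0], shape[1]);  // [2, 3]
  return 0;
}
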
@@ -168,6 +168,7 @@
#include "src/ops/random_standard_normal.h"
#include "src/ops/invert_permutation.h"
#include "src/ops/crop_and_resize.h"
#include "src/ops/nonzero.h"
#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@@ -1025,6 +1026,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
      return new (std::nothrow) RandomStandardNormal(primitive);
    case schema::PrimitiveType_CropAndResize:
      return new (std::nothrow) CropAndResize(primitive);
    case schema::PrimitiveType_NonZero:
      return new (std::nothrow) NonZero(primitive);
#ifdef SUPPORT_TRAIN
    case schema::PrimitiveType_ActivationGrad:
      return new (std::nothrow) ActivationGrad(primitive);
@@ -152,7 +152,9 @@ int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
      }
    }
  }
  if (input0->tensors_data_type() == kTypeUnknown) {
    input0->set_tensors_data_type(value_tensor->data_type());
  }
  out_shape[index] = value_tensor->shape();
  output0->MallocTensorListData(input0->tensors_data_type(), out_shape);
  return RET_OK;
@@ -172,7 +172,7 @@ int TensorListStack::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
int TensorListStack::MergeShape(const std::vector<int> &shape) {
  size_t dim0 = shape.size();
  size_t dim1 = output_shape_.size();
  if (dim1 >= unKnownRank_) {
  if (dim1 >= unKnownRank_ || output_shape_[0] == -1) {
    output_shape_ = shape;
    return RET_OK;
  }
@@ -150,6 +150,13 @@ int Transpose::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  for (size_t i = 0; i < perm.size(); ++i) {
    out_shape.at(i) = in_shape.at(perm.at(i));
  }
  if (perm.empty()) {
    auto shape_size = in_shape.size();
    out_shape.resize(shape_size);
    for (size_t i = 0; i < shape_size; ++i) {
      out_shape[shape_size - i - 1] = in_shape[i];
    }
  }
  output->set_shape(out_shape);
  return RET_OK;
}
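
The added branch gives Transpose a defined fallback: an empty perm reverses the dimension order. A small sketch of the resulting shape rule (the function name is illustrative):

#include <vector>

// With perm empty, the output shape is the input shape reversed
// (e.g. {2, 3, 4} -> {4, 3, 2}); otherwise perm picks the dimension order.
std::vector<int> TransposeOutShape(const std::vector<int> &in_shape, const std::vector<int> &perm) {
  std::vector<int> out_shape(in_shape.size());
  if (perm.empty()) {
    for (size_t i = 0; i < in_shape.size(); ++i) {
      out_shape[in_shape.size() - i - 1] = in_shape[i];
    }
  } else {
    for (size_t i = 0; i < perm.size(); ++i) {
      out_shape[i] = in_shape[perm[i]];
    }
  }
  return out_shape;
}
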
@@ -85,13 +85,28 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) {
int CarryDataKernel::MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) {
  // shape may change, because tensors.size() can be change in RunGraph
  if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() ||
      !(dst_tensor->element_shape() == src_tensor->element_shape() ||
        (dst_tensor->element_shape().empty() && src_tensor->element_shape().empty())) ||
      dst_tensor->tensors_data_type() != src_tensor->tensors_data_type()) {
    MS_LOG(ERROR) << "input tensorlist and output tensorlist is incompatible";
  if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) {
    MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
    return RET_ERROR;
  }
  if (dst_tensor->element_shape() != src_tensor->element_shape()) {
    MS_LOG(ERROR) << "input tensorlist and output tensorlist element shape is incompatible";
    return RET_ERROR;
  }
  auto update_data_type = kTypeUnknown;
  auto dst_tensor_data_type = dst_tensor->tensors_data_type();
  auto src_tensor_data_type = src_tensor->tensors_data_type();
  if (dst_tensor_data_type != src_tensor_data_type) {
    if (src_tensor_data_type != kTypeUnknown && dst_tensor_data_type != kTypeUnknown) {
      MS_LOG(ERROR) << "input tensorlist and output tensorlist is incompatible";
      return RET_ERROR;
    }
    update_data_type = dst_tensor_data_type != kTypeUnknown ? dst_tensor_data_type : src_tensor_data_type;
  }
  if (update_data_type != kTypeUnknown) {
    src_tensor->set_tensors_data_type(update_data_type);
    dst_tensor->set_tensors_data_type(update_data_type);
  }
  if (src_tensor->root_tensor() == nullptr) {
    dst_tensor->CopyTensorList(*src_tensor, false);
    src_tensor->set_tensors({});
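
The relaxed check above only rejects the pair when both element types are known and differ; an unknown side adopts the other's type. A condensed sketch of that rule (the TypeId values here are placeholders, not MindSpore's real enum):

enum TypeId { kTypeUnknown, kNumberTypeInt32, kNumberTypeFloat32 };  // illustrative values

// Returns false only on a genuine conflict; otherwise *out is the element
// type both tensor lists should carry after the move.
bool ReconcileTensorsDataType(TypeId dst, TypeId src, TypeId *out) {
  if (dst == src) {
    *out = dst;
    return true;
  }
  if (dst != kTypeUnknown && src != kTypeUnknown) {
    return false;  // both known and different: incompatible
  }
  *out = (dst != kTypeUnknown) ? dst : src;
  return true;
}
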
@@ -88,6 +88,9 @@ int CastCPUKernel::DoCast(int thread_id) {
  } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt16) {
    Float32ToInt16(reinterpret_cast<float *>(input->data_c()) + offset,
                   reinterpret_cast<int16_t *>(output_data) + offset, data_num);
  } else if (input_data_type == kNumberTypeBool && output_data_type == kNumberTypeInt32) {
    BoolToInt32(reinterpret_cast<bool *>(input->data_c()) + offset, reinterpret_cast<int32_t *>(output_data) + offset,
                data_num);
  } else {
    MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type;
    return RET_ERROR;
@@ -39,7 +39,7 @@ int NonZeroCPUKernel::ReSize() { return RET_OK; }
int NonZeroCPUKernel::Run() {
  auto in_tensor = in_tensors_.front();
  auto out_tensor = out_tensors_.front();
  auto input_data = reinterpret_cast<float *>(in_tensor->MutableData());
  auto input_data = reinterpret_cast<bool *>(in_tensor->MutableData());
  auto output_data = reinterpret_cast<int *>(out_tensor->MutableData());
  auto input_dim_size = in_tensor->shape().size();
  if (out_tensor->shape().size() != 2) {
@@ -50,56 +50,22 @@ int NonZeroCPUKernel::Run() {
  int non_zero_count = 0;
  std::vector<int> coordiate_values(in_tensor->shape().size(), 0);
  for (int i = 0; i < in_tensor->ElementsNum(); i += 1) {
    if (input_data[i] != 0) {
    if (input_data[i]) {
      for (size_t j = 0; j < input_dim_size; j++) {
        output_data[non_zero_count + j * non_zero_nums] = coordiate_values[j];
      }
      non_zero_count++;
    }
    for (int idx = input_dim_size - 1; idx >= 0; --idx) {
      if (coordiate_values[idx] != in_tensor->shape()[idx] - 1) {
        coordiate_values[idx] = coordiate_values[idx] + 1;
    for (size_t idx = input_dim_size; idx >= 1; --idx) {
      if (coordiate_values[idx - 1] != in_tensor->shape()[idx - 1] - 1) {
        coordiate_values[idx - 1] = coordiate_values[idx - 1] + 1;
        break;
      }
      coordiate_values[idx] = 0;
      coordiate_values[idx - 1] = 0;
    }
  }
  return RET_OK;
}
kernel::LiteKernel *CpuNonZeroFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
                                                const lite::InnerContext *ctx, const kernel::KernelKey &desc,
                                                const mindspore::lite::PrimitiveC *primitive) {
  if (opParameter == nullptr) {
    MS_LOG(ERROR) << "Input opParameter is nullptr!";
    return nullptr;
  }
  if (ctx == nullptr) {
    MS_LOG(ERROR) << "Input context is nullptr!";
    free(opParameter);
    return nullptr;
  }
  if (ctx->thread_num_ == 0) {
    MS_LOG(ERROR) << "context thread num is 0!";
    free(opParameter);
    return nullptr;
  }
  auto *kernel = new (std::nothrow) NonZeroCPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "new NonZeroCPUKernel fail!";
    free(opParameter);
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NonZero, CpuNonZeroFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_NonZero, LiteKernelCreator<NonZeroCPUKernel>)
}  // namespace mindspore::kernel
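
The kernel walks a flat index while maintaining N-D coordinates odometer-style; the rewritten loop counts a size_t idx from dim down to 1 and indexes idx - 1, sidestepping the unsigned wrap-around that an `idx >= 0` condition would hit. A standalone sketch of the increment:

#include <vector>

// Bump the last coordinate; on overflow reset it and carry into the
// previous dimension, mirroring the loop in NonZeroCPUKernel::Run.
void IncrementCoordinate(const std::vector<int> &shape, std::vector<int> *coord) {
  for (size_t idx = coord->size(); idx >= 1; --idx) {
    if ((*coord)[idx - 1] != shape[idx - 1] - 1) {
      (*coord)[idx - 1] += 1;
      break;
    }
    (*coord)[idx - 1] = 0;
  }
}
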
@@ -30,9 +30,23 @@ namespace mindspore::kernel {
int TensorListSetItemCPUKernel::Init() { return RET_OK; }
int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) {
  output0_ = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
  int new_tensors_size = origin_size + 1;
  output0_->set_shape({new_tensors_size});
  std::vector<std::vector<int>> out_shape;
  out_shape.resize(new_tensors_size, in_tensors_[2]->shape());
  auto ret = output0_->MallocTensorListData(in_tensors_[2]->data_type(), out_shape);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "increment output size malloc tensorlist data error";
    return ret;
  }
  return RET_OK;
}
int TensorListSetItemCPUKernel::Run() {
  input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
  if (dtype_ != input0_->tensors_data_type()) {
  if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) {
    MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
    return RET_ERROR;
  }
@@ -47,8 +61,10 @@ int TensorListSetItemCPUKernel::Run() {
  }
  index_ = reinterpret_cast<int *>(in_tensors_[1]->data_c())[0];
  if (index_ < 0 || index_ > dim0) {
    MS_LOG(ERROR) << "index tensor:[" << index_ << "] must be in [0, " << dim0 << "]!";
    return RET_ERROR;
    if (IncrementOutputSize(output0_->shape()[0]) != RET_OK) {
      MS_LOG(ERROR) << "Resizeoutput Error ,index tensor:[" << index_ << "] must be in [0, " << dim0 << "]!";
      return RET_ERROR;
    }
  }
  input2_ = in_tensors_[2];
  MS_ASSERT(input2_ != nullptr);
@@ -57,6 +73,13 @@ int TensorListSetItemCPUKernel::Run() {
  }
  output0_ = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
  MS_ASSERT(output0_ != nullptr);
  // new loop count
  if (output0_->ElementsNum() != static_cast<int>(output0_->tensors().size()) && output0_->tensors().empty()) {
    if (IncrementOutputSize(0) != RET_OK) {
      MS_LOG(ERROR) << "Resizeoutput Error!";
      return RET_ERROR;
    }
  }
  // copy each tensor in tensors_
  for (int i = 0; i < output0_->ElementsNum(); ++i) {
    if (i == index_) {
@@ -92,10 +115,6 @@ int TensorListSetItemCPUKernel::Run() {
    }
    if (src->data_type() != kTypeUnknown) {
      if (src->Size() != dst->Size()) {
        MS_LOG(ERROR) << "src->Size():" << src->Size() << " must be equal to dst->Size():" << dst->Size();
        return RET_ERROR;
      }
      auto ret = lite::Tensor::CopyTensorData(*src, dst);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "CopyTensorData[" << i << "] is failed!";
@@ -36,6 +36,7 @@ class TensorListSetItemCPUKernel : public LiteKernel {
  int Init() override;
  int ReSize() override;
  int Run() override;
  int IncrementOutputSize(int origin_size);
 private:
  lite::TensorList *input0_ = nullptr;
@@ -31,7 +31,7 @@ using mindspore::schema::PrimitiveType_TensorListStack;
namespace mindspore::kernel {
int TensorListStackCPUKernel::CheckParam() {
  if (input0_->tensors_data_type() != dtype_) {
  if (dtype_ != kTypeUnknown && input0_->tensors_data_type() != dtype_) {
    MS_LOG(ERROR) << "in_tensors_[0].tensors_data_type:[" << input0_->tensors_data_type() << "] must be equal "
                  << "param.data_type:[" << dtype_ << "]";
    return RET_ERROR;
@@ -113,6 +113,16 @@ int TensorListStackCPUKernel::MergeElementShape() {
int TensorListStackCPUKernel::MergeSubShape(const std::vector<int> &shape) {
  size_t dim0 = shape.size();
  size_t dim1 = output_shape_.size();
  // unknown shape use input element shape
  if (dim1 != 0 && output_shape_[0] == -1) {
    if (dim0 == 0) {
      output_shape_.clear();
      output_shape_.emplace_back(1);
    } else {
      output_shape_ = shape;
    }
    return RET_OK;
  }
  if (dim1 != dim0) {
    MS_LOG(ERROR) << "shape.size():" << dim1 << " must be equal output_shape_.size():" << dim0;
    return RET_ERROR;
@@ -38,6 +38,10 @@ int UnsqueezeCPUKernel::Init() {
int UnsqueezeCPUKernel::ReSize() {
  data_size_ = in_tensors_.at(0)->ElementsNum();
  thread_sz_count_ = MSMIN(context_->thread_num_, data_size_);
  if (thread_sz_count_ == 0) {
    thread_sz_stride_ = 0;
    return RET_OK;
  }
  thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
  return RET_OK;
}
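
The early return covers the empty-tensor case: an ElementsNum() of 0 makes thread_sz_count_ 0, and UP_DIV would then divide by zero. A sketch with an assumed UP_DIV definition (the real macro lives in nnacl):

#include <algorithm>

#define UP_DIV(a, b) (((a) + (b)-1) / (b))  // assumed definition, for illustration

int ComputeStride(int data_size, int max_threads) {
  int thread_count = std::min(max_threads, data_size);  // MSMIN equivalent
  if (thread_count == 0) {
    return 0;  // empty tensor: avoid UP_DIV's division by zero
  }
  return UP_DIV(data_size, thread_count);
}
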
@@ -35,3 +35,4 @@ ml_video_edit_img_segment_adaptise.pb;2
ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2
ml_video_edit_video_segment_gauss_adaptis_part2.pb;2
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2
tiny-yolov3-11.onnx;2;1,416,416,3:1,2
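
The new entry follows the `model;input_num;input_shapes` layout the reworked run scripts parse below, with per-input shapes joined by ':'. A rough C++ illustration of the same three-field split the awk calls perform (variable names are mine):

#include <iostream>
#include <sstream>
#include <string>

int main() {
  // tiny-yolov3-11.onnx declares 2 inputs with shapes 1,416,416,3 and 1,2.
  std::string line = "tiny-yolov3-11.onnx;2;1,416,416,3:1,2";
  std::string model_name, input_num, input_shapes;
  std::stringstream ss(line);
  std::getline(ss, model_name, ';');    // field 1: model file
  std::getline(ss, input_num, ';');     // field 2: input count
  std::getline(ss, input_shapes, ';');  // field 3: ':'-separated shapes
  std::cout << model_name << " " << input_num << " " << input_shapes << "\n";
  return 0;
}
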
@@ -582,10 +582,9 @@ function Run_x86() {
        if [[ $line == \#* ]]; then
            continue
        fi
        model_name=${line%%;*}
        model_name_len=${#model_name}
        input_params=${line:model_name_len+1}
        input_num=${input_params%%;*}
        model_name=`echo ${line} | awk -F ';' '{print $1}'`
        input_num=`echo ${line} | awk -F ';' '{print $2}'`
        input_shapes=`echo ${line} | awk -F ';' '{print $3}'`
        input_files=''
        output_file=''
        data_path="/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/"
@@ -606,8 +605,8 @@ function Run_x86() {
        echo ${model_name} >> "${run_x86_log_file}"
        echo 'cd '${x86_path}'/mindspore-lite-'${version}'-inference-linux-x64' >> "{run_x86_log_file}"
        cd ${x86_path}/mindspore-lite-${version}-inference-linux-x64 || return 1
        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile='${output_file}' --loopCount=1 --warmUpLoopCount=0' >> "${run_x86_log_file}"
        export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}.ms --inDataFile=${input_files} --benchmarkDataFile=${output_file} --loopCount=1 --warmUpLoopCount=0 >> "${run_x86_log_file}"
        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile='${input_files}' --inputShapes='${input_shapes}' --benchmarkDataFile='${output_file}' --loopCount=1 --warmUpLoopCount=0' >> "${run_x86_log_file}"
        export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}.ms --inDataFile=${input_files} --inputShapes=${input_shapes} --benchmarkDataFile=${output_file} --loopCount=1 --warmUpLoopCount=0 >> "${run_x86_log_file}"
        if [ $? = 0 ]; then
            run_result='x86: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
        else
@@ -851,9 +850,9 @@ function Run_x86_sse() {
        if [[ $model_name == \#* ]]; then
            continue
        fi
        model_name_len=${#model_name}
        input_params=${line:model_name_len+1}
        input_num=${input_params%%;*}
        model_name=`echo ${line} | awk -F ';' '{print $1}'`
        input_num=`echo ${line} | awk -F ';' '{print $2}'`
        input_shapes=`echo ${line} | awk -F ';' '{print $3}'`
        input_files=''
        output_file=''
        data_path="/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/"
@@ -874,8 +873,8 @@ function Run_x86_sse() {
        echo ${model_name} >> "${run_x86_sse_log_file}"
        echo 'cd '${x86_path}'/mindspore-lite-'${version}'-inference-linux-x64-sse' >> "{run_x86_sse_log_file}"
        cd ${x86_path}/mindspore-lite-${version}-inference-linux-x64-sse || return 1
        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile='${output_file}' --loopCount=1 --warmUpLoopCount=0' >> "${run_x86_sse_log_file}"
        export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}.ms --inDataFile=${input_files} --benchmarkDataFile=${output_file} --loopCount=1 --warmUpLoopCount=0 >> "${run_x86_sse_log_file}"
        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile='${input_files}' --inputShapes='${input_shapes}' --benchmarkDataFile='${output_file}' --loopCount=1 --warmUpLoopCount=0' >> "${run_x86_sse_log_file}"
        export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}.ms --inDataFile=${input_files} --inputShapes=${input_shapes} --benchmarkDataFile=${output_file} --loopCount=1 --warmUpLoopCount=0 >> "${run_x86_sse_log_file}"
        if [ $? = 0 ]; then
            run_result='x86_sse: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
        else
@@ -1119,9 +1118,9 @@ function Run_x86_avx() {
        if [[ $model_name == \#* ]]; then
            continue
        fi
        model_name_len=${#model_name}
        input_params=${line:model_name_len+1}
        input_num=${input_params%%;*}
        model_name=`echo ${line} | awk -F ';' '{print $1}'`
        input_num=`echo ${line} | awk -F ';' '{print $2}'`
        input_shapes=`echo ${line} | awk -F ';' '{print $3}'`
        input_files=''
        output_file=''
        data_path="/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/"
@@ -1142,8 +1141,8 @@ function Run_x86_avx() {
        echo ${model_name} >> "${run_x86_avx_log_file}"
        echo 'cd '${x86_path}'/mindspore-lite-'${version}'-inference-linux-x64-avx' >> "{run_x86_avx_log_file}"
        cd ${x86_path}/mindspore-lite-${version}-inference-linux-x64-avx || return 1
        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile='${output_file}' --loopCount=1 --warmUpLoopCount=0' >> "${run_x86_avx_log_file}"
        export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}.ms --inDataFile=${input_files} --benchmarkDataFile=${output_file} --loopCount=1 --warmUpLoopCount=0 >> "${run_x86_avx_log_file}"
        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile='${input_files}' --inputShapes='${input_shapes}' --benchmarkDataFile='${output_file}' --loopCount=1 --warmUpLoopCount=0' >> "${run_x86_avx_log_file}"
        export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}.ms --inDataFile=${input_files} --inputShapes=${input_shapes} --benchmarkDataFile=${output_file} --loopCount=1 --warmUpLoopCount=0 >> "${run_x86_avx_log_file}"
        if [ $? = 0 ]; then
            run_result='x86_avx: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
        else
@@ -1673,9 +1672,9 @@ function Run_arm64() {
        if [[ $model_name == \#* ]]; then
            continue
        fi
        model_name_len=${#model_name}
        input_params=${line:model_name_len+1}
        input_num=${input_params%%;*}
        model_name=`echo ${line} | awk -F ';' '{print $1}'`
        input_num=`echo ${line} | awk -F ';' '{print $2}'`
        input_shapes=`echo ${line} | awk -F ';' '{print $3}'`
        input_files=''
        output_file=''
        data_path="/data/local/tmp/input_output/"
@@ -1696,8 +1695,8 @@ function Run_arm64() {
        fi
        echo ${model_name} >> "${run_arm64_log_file}"
        echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile='${output_file} >> "${run_arm64_log_file}"
        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile='${output_file} >> adb_run_cmd.txt
        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --inputShapes='${input_shapes}' --benchmarkDataFile='${output_file} >> "${run_arm64_log_file}"
        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'.ms --inDataFile='${input_files}' --inputShapes='${input_shapes}' --benchmarkDataFile='${output_file} >> adb_run_cmd.txt
        adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
        if [ $? = 0 ]; then
            run_result='arm64: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
@@ -66,7 +66,6 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
  // for now - training is not supporting fuse operations
  if (!config->trainModel) {
    // remove quantdtype when awaretraining
    fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
    fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
    auto conv_bn_pass = std::make_shared<opt::ConvBatchNormFusion>();
    conv_bn_pass->SetFmkType(config->fmk);
@@ -147,6 +146,7 @@ int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
                                   const converter::Flags *config) {
  auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
  const_fold_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
  if (!config->trainModel) {
    auto inne_context_ptr = std::make_shared<lite::InnerContext>();
    inne_context_ptr->Init();
@@ -143,6 +143,13 @@ STATUS InferShapePass::Run(MetaGraphT *graph) {
      }
    }
  }
  for (auto g_input_idx : graph->inputIndex) {
    auto g_input_shape = graph->allTensors.at(g_input_idx)->dims;
    if (std::find(g_input_shape.begin(), g_input_shape.end(), -1) != g_input_shape.end()) {
      MS_LOG(INFO) << "InferShape shouldn't be done before runtime";
      return RET_OK;
    }
  }
  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
    auto &node = *iter;
    auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex, node->primitive->value.type);
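
The added loop in miniature: if any graph input still carries a -1 (dynamic) dimension, offline shape inference bails out and defers the work to runtime. A minimal sketch of the check:

#include <algorithm>
#include <vector>

// True when a dims vector contains the -1 dynamic-dimension marker,
// i.e. when InferShapePass should return early.
bool HasDynamicDim(const std::vector<int> &dims) {
  return std::find(dims.begin(), dims.end(), -1) != dims.end();
}
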
@@ -131,9 +131,11 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
      contain_node_output_subgraphs.push_back(subgraph.get());
    }
  }
  std::set_intersection(contain_node_input_subgraphs.begin(), contain_node_input_subgraphs.end(),
                        contain_node_output_subgraphs.begin(), contain_node_output_subgraphs.end(),
                        inserter(contain_subgraphs, contain_subgraphs.begin()));
  for (auto subgraph : contain_node_input_subgraphs) {
    if (IsContain(contain_node_output_subgraphs, subgraph)) {
      contain_subgraphs.emplace_back(subgraph);
    }
  }
  if (contain_subgraphs.size() == 1) {
    IncreaseSubgraphNodeIndices(i, graph);
    contain_subgraphs[0]->nodeIndices.push_back(i);
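
std::set_intersection requires both input ranges to be sorted, which these subgraph pointer vectors are not guaranteed to be, so the patch swaps it for a linear membership scan. A sketch of the IsContain-style helper assumed here:

#include <algorithm>
#include <vector>

// Linear membership test; correct regardless of element order, unlike
// std::set_intersection, which misbehaves on unsorted ranges.
template <typename T>
bool IsContain(const std::vector<T> &vec, const T &value) {
  return std::find(vec.begin(), vec.end(), value) != vec.end();
}
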
@@ -165,56 +165,25 @@ STATUS SingleSwitchPass::BodyGraphVariableInput(std::vector<size_t> *variable_input) {
  return RET_OK;
}
STATUS SingleSwitchPass::InsertMerge() {
  // update body graph output
  auto &body_fg = graph_->subGraph.at(second_subgraph_index_);
  body_fg->outputIndices.assign(body_to_cond_partial_node_->inputIndex.begin(),
                                body_to_cond_partial_node_->inputIndex.end());
  // remove body_to_cond_partial_node_ from second_graph_nodes_
  for (auto it = second_graph_nodes_.begin(); it != second_graph_nodes_.end();) {
    if (*it == body_to_cond_partial_node_) {
      it = second_graph_nodes_.erase(it);
    } else {
      it++;
    }
  }
  // isolate body_to_cond_partial_node_
  IsolateUselessNode(body_to_cond_partial_node_, graph_);
  std::vector<size_t> variable_input{};
  int ret = BodyGraphVariableInput(&variable_input);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "get body graph variable input failed, ret: " << ret;
    return ret;
  }
  std::vector<size_t> const_input{};
  for (size_t i = 0; i < second_partial_node_->inputIndex.size(); i++) {
    if (IsContain(variable_input, i)) {
      continue;
    }
    const_input.push_back(i);
  }
std::unique_ptr<schema::CNodeT> SingleSwitchPass::MakeMergeNode(const std::string &name,
                                                                const std::vector<size_t> &const_input) {
  auto merge_node = std::make_unique<schema::CNodeT>();
  if (merge_node == nullptr) {
    MS_LOG(ERROR) << "new CNodeT failed";
    return RET_NULL_PTR;
    return nullptr;
  }
  merge_node->primitive = std::make_unique<PrimitiveT>();
  if (merge_node->primitive == nullptr) {
    MS_LOG(ERROR) << "new PrimitiveT failed";
    return RET_NULL_PTR;
    return nullptr;
  }
  merge_node->name = switch_node_->name + "-merge";
  merge_node->name = name;
  merge_node->primitive->value.type = schema::PrimitiveType_Merge;
  merge_node->primitive->value.value = new (std::nothrow) MergeT();
  if (merge_node->primitive->value.value == nullptr) {
    MS_LOG(ERROR) << "new MergeT failed";
    return RET_NULL_PTR;
    return nullptr;
  }
  // merge node output is same as switch
@@ -251,7 +220,46 @@ STATUS SingleSwitchPass::InsertMerge() {
      merge_node->inputIndex.push_back(graph_->allTensors.size() - 1);
    }
  }
  return merge_node;
}
STATUS SingleSwitchPass::InsertMerge() {
  // update body graph output
  auto &body_fg = graph_->subGraph.at(second_subgraph_index_);
  body_fg->outputIndices.assign(body_to_cond_partial_node_->inputIndex.begin(),
                                body_to_cond_partial_node_->inputIndex.end());
  // remove body_to_cond_partial_node_ from second_graph_nodes_
  for (auto it = second_graph_nodes_.begin(); it != second_graph_nodes_.end();) {
    if (*it == body_to_cond_partial_node_) {
      it = second_graph_nodes_.erase(it);
    } else {
      it++;
    }
  }
  // isolate body_to_cond_partial_node_
  IsolateUselessNode(body_to_cond_partial_node_, graph_);
  std::vector<size_t> variable_input{};
  int ret = BodyGraphVariableInput(&variable_input);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "get body graph variable input failed, ret: " << ret;
    return ret;
  }
  std::vector<size_t> const_input{};
  for (size_t i = 0; i < second_partial_node_->inputIndex.size(); i++) {
    if (IsContain(variable_input, i)) {
      continue;
    }
    const_input.push_back(i);
  }
  auto merge_node = MakeMergeNode(switch_node_->name + "-merge", const_input);
  if (merge_node == nullptr) {
    MS_LOG(ERROR) << "make merge node failed";
    return ret;
  }
  // insert merge node before the cond graph
  std::map<int, int> cond_input_update_map{};
  for (size_t i = 0; i < first_partial_node_->inputIndex.size(); i++) {
@@ -591,6 +599,11 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche
      output = subgraph_output_map.at(output);
    }
  }
  for (auto &input : subgraph_node->inputIndex) {
    if (subgraph_output_map.find(input) != subgraph_output_map.end()) {
      input = subgraph_output_map.at(input);
    }
  }
  }
  std::vector<int> new_subgraph_outputs{};
@@ -50,6 +50,7 @@ class SingleSwitchPass {
  STATUS ConcatBodySubgraphInputAndOutput();
  bool IsLoop();
  STATUS InsertMerge();
  std::unique_ptr<schema::CNodeT> MakeMergeNode(const std::string &name, const std::vector<size_t> &const_in);
  // function for if
  STATUS InsertPartialAndMergeAfterSwitch();
@@ -57,8 +57,10 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st
  static auto root_func_manager = Manage(anf_root_graph_);
  for (auto &subgraph : all_subgraphs_) {
    subgraph->set_manager(root_func_manager);
    subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
  }
  anf_root_graph_->set_attr("graph_name", MakeValue("main_graph"));
  anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
  return anf_root_graph_;
}
@@ -165,6 +167,7 @@ STATUS OnnxModelParser::ConvertGraphInputs(const onnx::GraphProto &onnx_graph, c
    auto onnx_shape = input_value.type().tensor_type().shape().dim();
    std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector),
                   [](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); });
    std::replace(shape_vector.begin(), shape_vector.end(), 0, -1);
    auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
    parameter->set_abstract(abstract_tensor);
    parameter->set_name(input_value.name());
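
ONNX leaves dim_value() at 0 for symbolic (named) dimensions, while MindSpore Lite marks dynamic dimensions with -1, hence the std::replace on the converted shape. A small sketch of the effect:

#include <algorithm>
#include <cstdint>
#include <vector>

int main() {
  // A batch dimension ONNX reported as 0 (symbolic) becomes -1 (dynamic).
  std::vector<int64_t> shape_vector = {0, 416, 416, 3};
  std::replace(shape_vector.begin(), shape_vector.end(), static_cast<int64_t>(0), static_cast<int64_t>(-1));
  // shape_vector is now {-1, 416, 416, 3}
  return 0;
}
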
@@ -746,6 +749,41 @@ STATUS AddIterNumsUpdateEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodeP
  return RET_OK;
}
STATUS OnnxModelParser::AddTensorListStackNode(const AnfNodePtr &root_while_node, const onnx::NodeProto &onnx_node,
                                               int act_outputs_num, int body_output_size) {
  auto &loop_node_name = onnx_node.name();
  auto root_anf_graph = root_while_node->func_graph();
  auto stack_elem_node = CreateConstParamter(root_anf_graph, -1);
  stack_elem_node->set_name(loop_node_name + "_element_shape");
  for (int j = 0; j < act_outputs_num; j++) {
    auto output_size = onnx_node.output_size();
    auto &loop_output_name = onnx_node.output(output_size - act_outputs_num + j);
    auto &while_output_node = control_nodes_map_[loop_node_name]->at(loop_output_name);
    auto stack_attr = std::make_unique<schema::TensorListStackT>();
    if (stack_attr == nullptr) {
      MS_LOG(ERROR) << "new op failed";
      return RET_ERROR;
    }
    stack_attr->numElements = -1;
    auto stack_value_node = CreateValueNode(stack_attr.release(), schema::PrimitiveType_TensorListStack);
    std::vector<AnfNodePtr> stack_inputs = {stack_value_node, while_output_node, stack_elem_node};
    auto tensorlist_stack_cnode = root_anf_graph->NewCNode(stack_inputs);
    if (tensorlist_stack_cnode == nullptr) {
      MS_LOG(ERROR) << "new cnode error";
      return RET_ERROR;
    }
    tensorlist_stack_cnode->set_fullname_with_scope(loop_node_name + "_tensorlist_stack_node_" + std::to_string(j));
    tensorlist_stack_cnode->set_abstract(stack_elem_node->abstract());
    // update getitem value output index
    auto new_get_item_value = NewValueNode(MakeValue<int>(body_output_size - act_outputs_num + j));
    while_output_node->cast<CNodePtr>()->set_input(2, new_get_item_value);
    // insert tensorliststack after while_output
    (*control_nodes_map_[loop_node_name])[loop_output_name] = tensorlist_stack_cnode;
  }
  return RET_OK;
}
// onnx loop scan_output need through tensorlist op,while node need add new inputs
STATUS OnnxModelParser::AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
                                           const std::string &loop_node_name,
@@ -874,34 +912,10 @@ STATUS OnnxModelParser::ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
    return status;
  }
  // insert tensorliststack after while output
  auto root_anf_graph = root_while_node->func_graph();
  auto stack_elem_node = CreateConstParamter(root_anf_graph, -1);
  stack_elem_node->set_name(loop_node_name + "_element_shape");
  for (int j = 0; j < act_outputs_num; j++) {
    auto output_size = onnx_node.output_size();
    auto &loop_output_name = onnx_node.output(output_size - act_outputs_num + j);
    auto &while_output_node = control_nodes_map_[loop_node_name]->at(loop_output_name);
    auto stack_attr = std::make_unique<schema::TensorListStackT>();
    if (stack_attr == nullptr) {
      MS_LOG(ERROR) << "new op failed";
      return RET_ERROR;
    }
    auto stack_value_node = CreateValueNode(stack_attr.release(), schema::PrimitiveType_TensorListStack);
    std::vector<AnfNodePtr> stack_inputs = {stack_value_node, while_output_node, stack_elem_node};
    auto tensorlist_stack_cnode = root_anf_graph->NewCNode(stack_inputs);
    if (tensorlist_stack_cnode == nullptr) {
      MS_LOG(ERROR) << "new cnode error";
      return RET_ERROR;
    }
    tensorlist_stack_cnode->set_fullname_with_scope(loop_node_name + "_tensorlist_stack_node_" + std::to_string(j));
    tensorlist_stack_cnode->set_abstract(stack_elem_node->abstract());
    // update getitem value output index
    auto new_get_item_value = NewValueNode(MakeValue<int>(body_graph_inputs.size() - act_outputs_num + i));
    while_output_node->cast<CNodePtr>()->set_input(2, new_get_item_value);
    // insert tensorliststack after while_output
    (*control_nodes_map_[loop_node_name])[loop_output_name] = tensorlist_stack_cnode;
  status = AddTensorListStackNode(root_while_node, onnx_node, act_outputs_num, body_graph_inputs.size());
  if (status != RET_OK) {
    MS_LOG(ERROR) << "add tensorliststack node failed";
    return status;
  }
  }
  return_tuple_cnode->set_inputs(return_new_inputs);
@@ -96,6 +96,8 @@ class OnnxModelParser : public ModelParser {
  STATUS AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
                            const std::string &loop_node_name, std::vector<AnfNodePtr> *body_graph_inputs,
                            int act_output_num);
  STATUS AddTensorListStackNode(const AnfNodePtr &root_while_node, const onnx::NodeProto &onnx_node, int act_output_num,
                                int body_output_size);
  STATUS BuildCondGraph(const FuncGraphPtr &cond_graph, const AnfNodePtr &root_while_node, int inputs_num,
                        const std::string &cond_graph_name);
  STATUS ConvertIfSubgraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
@@ -14,11 +14,11 @@
 * limitations under the License.
 */
#include "tools/optimizer/graph/onnx_inputs_adjust_pass.h"
#include <vector>
#include <string>
#include <algorithm>
#include <functional>
#include <memory>
#include <algorithm>
#include <string>
#include <vector>
#include "mindspore/lite/include/errorcode.h"
#include "src/ops/primitive_c.h"
@@ -266,11 +266,11 @@ STATUS OnnxInputAdjustOpPass::AdjustStridedSlice(const FuncGraphPtr &func_graph,
  auto inputs = cnode->inputs();
  switch (cnode->inputs().size()) {
    case 4: {
      std::vector<int> axises;
      std::vector<int> axes;
      for (int i = 0; i < size; ++i) {
        axises.push_back(i);
        axes.push_back(i);
      }
      auto new_param_node = BuildParameterNode(func_graph, axises, cnode->fullname_with_scope() + "_axises");
      auto new_param_node = BuildParameterNode(func_graph, axes, cnode->fullname_with_scope() + "_axes");
      if (new_param_node == nullptr) {
        MS_LOG(ERROR) << "new a parameter node failed.";
      }
@@ -327,9 +327,9 @@ STATUS OnnxInputAdjustOpPass::AdjustResize(const CNodePtr &cnode) {
  }
  auto attr = reinterpret_cast<schema::ResizeT *>(value);
  if (cnode->inputs().size() > 3 &&
      attr->coordinateTransformMode == schema::CoordinateTransformMode_TF_CROP_AND_RESIZE) {
      attr->coordinateTransformMode != schema::CoordinateTransformMode_TF_CROP_AND_RESIZE) {
    auto new_resize_inputs = cnode->inputs();
    new_resize_inputs.erase(new_resize_inputs.begin() + 1);
    new_resize_inputs.erase(new_resize_inputs.begin() + 2);
    cnode->set_inputs(new_resize_inputs);
  }
  if (cnode->inputs().size() > 3 && attr->coordinateTransformMode == schema::CoordinateTransformMode_HALF_PIXEL) {
@@ -597,6 +597,8 @@ bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) {
      status = ReplaceTransposeWithGraphInput(func_graph, cnode);
    } else if (type == schema::PrimitiveType_Resize) {
      status = AdjustResize(cnode);
    } else {
      continue;
    }
    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
      MS_LOG(ERROR) << "adjust input pass is failed.";