From: @YeFeng_24 Reviewed-by: @hangangqiang Signed-off-by:pull/13946/MERGE
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| #include "include/errorcode.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.h" | |||
| #include "src/runtime/kernel/arm/base/tensorlist_fromtensor.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| @@ -52,14 +52,7 @@ int TensorListFromTensorCPUKernel::IsCompatibleShape() { | |||
| return RET_OK; | |||
| } | |||
| int TensorListFromTensorCPUKernel::Init() { | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { | |||
| dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| int TensorListFromTensorCPUKernel::Init() { return RET_OK; } | |||
| int TensorListFromTensorCPUKernel::ReSize() { return RET_OK; } | |||
| @@ -71,6 +64,7 @@ int TensorListFromTensorCPUKernel::Run() { | |||
| MS_LOG(ERROR) << "IsNotCompatibleShape!"; | |||
| return RET_ERROR; | |||
| } | |||
| dtype_ = in_tensors_[0]->data_type(); | |||
| if (input0_->shape().size() == 0) { | |||
| MS_LOG(ERROR) << "input0_->shape().size():" << input0_->shape().size() << " must be greater than 0"; | |||
| } | |||
| @@ -97,9 +91,10 @@ int TensorListFromTensorCPUKernel::Run() { | |||
| << " must be euqal to devision_dim0:" << devision_dim0; | |||
| return RET_ERROR; | |||
| } | |||
| auto out_data = out_ptr->MutableData(); | |||
| auto out_data = out_ptr->data_c(); | |||
| MS_ASSERT(out_data != nullptr); | |||
| memcpy(out_data, in_data, data_offset); | |||
| out_ptr->set_data_type(dtype_); | |||
| in_data += data_offset; | |||
| } | |||
| output0->set_tensors_data_type(dtype_); | |||
| @@ -16,7 +16,7 @@ | |||
| #include "include/errorcode.h" | |||
| #include "include/ms_tensor.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.h" | |||
| #include "src/runtime/kernel/arm/base/tensorlist_getitem.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| @@ -28,16 +28,7 @@ using mindspore::schema::PrimitiveType_TensorListGetItem; | |||
| namespace mindspore::kernel { | |||
| int TensorListGetItemCPUKernel::Init() { | |||
| MS_ASSERT(in_tensors_.size() >= 2); | |||
| MS_ASSERT(in_tensors_.at(0) != nullptr); | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { | |||
| dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| int TensorListGetItemCPUKernel::Init() { return RET_OK; } | |||
| int TensorListGetItemCPUKernel::Run() { | |||
| MS_ASSERT(in_tensors_.size() >= 2); | |||
| @@ -48,10 +39,7 @@ int TensorListGetItemCPUKernel::Run() { | |||
| if (input0->root_tensor() != nullptr) { | |||
| input0 = reinterpret_cast<lite::TensorList *>(input0->root_tensor()); | |||
| } | |||
| if (dtype_ != input0->tensors_data_type()) { | |||
| MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| dtype_ = input0->tensors_data_type(); | |||
| MS_ASSERT(in_tensors_.at(1)->data_c() != nullptr); | |||
| index_ = reinterpret_cast<int *>(in_tensors_.at(1)->data_c())[0]; | |||
| int dim0 = input0->ElementsNum() - 1; | |||
| @@ -16,7 +16,7 @@ | |||
| #include <vector> | |||
| #include "include/errorcode.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/fp32/tensorlist_reserve_fp32.h" | |||
| #include "src/runtime/kernel/arm/base/tensorlist_reserve.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -27,14 +27,7 @@ using mindspore::schema::PrimitiveType_TensorListReserve; | |||
| namespace mindspore::kernel { | |||
| int TensorListReserveCPUKernel::Init() { | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && element_dtype_ == kNumberTypeFloat32) { | |||
| element_dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| int TensorListReserveCPUKernel::Init() { return RET_OK; } | |||
| int TensorListReserveCPUKernel::Run() { | |||
| auto input0 = in_tensors_.at(0); | |||
| @@ -16,7 +16,7 @@ | |||
| #include "include/errorcode.h" | |||
| #include "include/ms_tensor.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.h" | |||
| #include "src/runtime/kernel/arm/base/tensorlist_setitem.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| @@ -28,21 +28,9 @@ using mindspore::schema::PrimitiveType_TensorListSetItem; | |||
| namespace mindspore::kernel { | |||
| int TensorListSetItemCPUKernel::Init() { | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { | |||
| dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| int TensorListSetItemCPUKernel::Init() { return RET_OK; } | |||
| int TensorListSetItemCPUKernel::CheckParam() { | |||
| if (dtype_ != kTypeUnknown && input0_->tensors_data_type() != kTypeUnknown && | |||
| dtype_ != input0_->tensors_data_type()) { | |||
| MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->tensors_data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int"; | |||
| return RET_ERROR; | |||
| @@ -70,7 +58,6 @@ int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) { | |||
| int TensorListSetItemCPUKernel::Run() { | |||
| input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]); | |||
| if (CheckParam() != RET_OK) { | |||
| MS_LOG(ERROR) << "check param failed."; | |||
| return RET_ERROR; | |||
| @@ -28,8 +28,7 @@ class TensorListSetItemCPUKernel : public LiteKernel { | |||
| public: | |||
| TensorListSetItemCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||
| : LiteKernel(parameter, inputs, outputs, ctx), | |||
| dtype_(static_cast<TypeId>(reinterpret_cast<TensorListParameter *>(parameter)->element_dtype_)) {} | |||
| : LiteKernel(parameter, inputs, outputs, ctx) {} | |||
| ~TensorListSetItemCPUKernel() = default; | |||
| int Init() override; | |||
| @@ -43,7 +42,6 @@ class TensorListSetItemCPUKernel : public LiteKernel { | |||
| lite::Tensor *input2_ = nullptr; | |||
| lite::TensorList *output0_ = nullptr; | |||
| int index_ = 0; | |||
| TypeId dtype_ = kTypeUnknown; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -19,7 +19,7 @@ | |||
| #include "include/errorcode.h" | |||
| #include "ir/dtype/type_id.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.h" | |||
| #include "src/runtime/kernel/arm/base/tensorlist_stack.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -31,11 +31,6 @@ using mindspore::schema::PrimitiveType_TensorListStack; | |||
| namespace mindspore::kernel { | |||
| int TensorListStackCPUKernel::CheckParam() { | |||
| if (dtype_ != kTypeUnknown && input0_->tensors_data_type() != dtype_) { | |||
| MS_LOG(ERROR) << "in_tensors_[0].tensors_data_type:[" << input0_->tensors_data_type() << "] must be equal " | |||
| << "param.data_type:[" << dtype_ << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| if (num_element_ != -1 && input0_->ElementsNum() != num_element_) { | |||
| MS_LOG(ERROR) << "in_tensors_[0].ElementsNum():[" << input0_->ElementsNum() << "] must be equal " | |||
| << "param.elements_num:[" << num_element_ << "]"; | |||
| @@ -60,11 +55,6 @@ int TensorListStackCPUKernel::Init() { | |||
| MS_ASSERT(input0_ != nullptr); | |||
| output0_ = out_tensors_[0]; | |||
| MS_ASSERT(output0_ != nullptr); | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { | |||
| dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| @@ -146,18 +136,11 @@ int TensorListStackCPUKernel::MergeSubShape(const std::vector<int> &shape) { | |||
| } | |||
| int TensorListStackCPUKernel::Run() { | |||
| if (dtype_ == kTypeUnknown) { | |||
| dtype_ = input0_->tensors_data_type(); | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { | |||
| dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| } | |||
| if (CheckParam() != RET_OK) { | |||
| MS_LOG(ERROR) << "CheckParam failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| dtype_ = input0_->tensors_data_type(); | |||
| if (output0_->ElementsNum() == 0) { | |||
| return RET_OK; | |||
| } | |||
| @@ -76,7 +76,11 @@ int ArithmeticSelfFp16CPUKernel::Run() { | |||
| auto input_tensor = in_tensors_.at(0); | |||
| auto output_tensor = out_tensors_.at(0); | |||
| input_fp16_ptr_ = reinterpret_cast<float16_t *>(input_tensor->data_c()); | |||
| if (input_tensor->data_type() == kNumberTypeFloat32) { | |||
| input_fp16_ptr_ = ConvertInputFp32toFp16(input_tensor, context_); | |||
| } else { | |||
| input_fp16_ptr_ = reinterpret_cast<float16_t *>(input_tensor->data_c()); | |||
| } | |||
| output_fp16_ptr_ = reinterpret_cast<float16_t *>(output_tensor->data_c()); | |||
| auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticSelfRun, this, op_parameter_->thread_num_); | |||
| @@ -199,7 +199,8 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_inter | |||
| } | |||
| kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors, | |||
| const std::vector<Tensor *> &out_tensors, const Model::Node *node) { | |||
| const std::vector<Tensor *> &out_tensors, const Model::Node *node, | |||
| TypeId prefer_data_type) { | |||
| kernel::LiteKernel *kernel = nullptr; | |||
| TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors); | |||
| OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)]; | |||
| @@ -272,7 +273,8 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||
| } | |||
| } | |||
| #endif | |||
| if (mindspore::lite::IsSupportFloat16() && | |||
| if ((prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) && | |||
| mindspore::lite::IsSupportFloat16() && | |||
| ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { | |||
| kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; | |||
| auto tensor_origin_data_map = | |||
| @@ -301,15 +303,17 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||
| MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; | |||
| desc.data_type = kNumberTypeFloat32; | |||
| } | |||
| auto tensor_origin_data_map = DequantUtil::DequantTensor(op_parameter, in_tensors, desc.data_type, need_restore); | |||
| auto ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, op_parameter, &kernel); | |||
| DequantUtil::RestoreTensorData(tensor_origin_data_map); | |||
| if (ret == RET_OK) { | |||
| return kernel; | |||
| } else if (ret == RET_ERROR) { | |||
| ret = InferNodeShape(node, &infer_shape_interrupt); | |||
| if (!(ret == RET_INFER_INVALID || ret == RET_OK)) { | |||
| MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_; | |||
| if (prefer_data_type == kNumberTypeFloat32 || prefer_data_type == kTypeUnknown) { | |||
| auto tensor_origin_data_map = DequantUtil::DequantTensor(op_parameter, in_tensors, desc.data_type, need_restore); | |||
| auto ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, op_parameter, &kernel); | |||
| DequantUtil::RestoreTensorData(tensor_origin_data_map); | |||
| if (ret == RET_OK) { | |||
| return kernel; | |||
| } else if (ret == RET_ERROR) { | |||
| ret = InferNodeShape(node, &infer_shape_interrupt); | |||
| if (!(ret == RET_INFER_INVALID || ret == RET_OK)) { | |||
| MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_; | |||
| } | |||
| } | |||
| } | |||
| return nullptr; | |||
| @@ -327,7 +331,7 @@ kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node * | |||
| std::vector<kernel::LiteKernel *> sub_kernels; | |||
| std::vector<lite::Tensor *> in_tensors; | |||
| std::vector<lite::Tensor *> out_tensors; | |||
| auto ret = ScheduleSubGraphToKernels(sub_graph_index, &sub_kernels, &in_tensors, &out_tensors); | |||
| auto ret = ScheduleSubGraphToKernels(sub_graph_index, &sub_kernels, &in_tensors, &out_tensors, kNumberTypeFloat32); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Schedule partial failed, name: " << src_node->name_; | |||
| return nullptr; | |||
| @@ -338,11 +342,11 @@ kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node * | |||
| return subgraph; | |||
| } | |||
| kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src_node) { | |||
| kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src_node, TypeId prefer_data_type) { | |||
| std::vector<Tensor *> inputs; | |||
| std::vector<Tensor *> outputs; | |||
| FindNodeInoutTensors(*src_node, &inputs, &outputs); | |||
| auto *kernel = this->FindBackendKernel(inputs, outputs, src_node); | |||
| auto *kernel = this->FindBackendKernel(inputs, outputs, src_node, prefer_data_type); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << src_node->name_ | |||
| << ", type: " << PrimitiveTypeName(GetPrimitiveType(src_node->primitive_)); | |||
| @@ -355,7 +359,7 @@ kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src | |||
| int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::LiteKernel *> *dst_kernels, | |||
| std::vector<lite::Tensor *> *in_tensors, | |||
| std::vector<lite::Tensor *> *out_tensors) { | |||
| std::vector<lite::Tensor *> *out_tensors, TypeId prefer_data_type) { | |||
| MS_ASSERT(src_model_ != nullptr); | |||
| MS_ASSERT(!src_model_->sub_graphs_.empty()); | |||
| MS_ASSERT(src_model_->sub_graphs_.size() > subgraph_index); | |||
| @@ -372,7 +376,7 @@ int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kern | |||
| if (IsPartialNode(primitive)) { // sub_graph | |||
| kernel = SchedulePartialToKernel(node); | |||
| } else { // kernel | |||
| kernel = ScheduleNodeToKernel(node); | |||
| kernel = ScheduleNodeToKernel(node, prefer_data_type); | |||
| } | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << node->name_ | |||
| @@ -57,14 +57,16 @@ class Scheduler { | |||
| // schedule a node to kernel according to context and kernels registered | |||
| kernel::LiteKernel *FindBackendKernel(const std::vector<Tensor *> &in_tensors, | |||
| const std::vector<Tensor *> &out_tensors, const Model::Node *node); | |||
| const std::vector<Tensor *> &out_tensors, const Model::Node *node, | |||
| TypeId prefer_data_type = kTypeUnknown); | |||
| // schedule a partial node to a subgraph_kernel | |||
| kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node); | |||
| // schedule a node to a kernel | |||
| kernel::LiteKernel *ScheduleNodeToKernel(const lite::Model::Node *src_node); | |||
| kernel::LiteKernel *ScheduleNodeToKernel(const lite::Model::Node *src_node, TypeId prefer_data_type = kTypeUnknown); | |||
| // schedule a Model::SubGraph into a vector of kernel and subgraph_kernel | |||
| int ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::LiteKernel *> *dst_kernels, | |||
| std::vector<lite::Tensor *> *in_tensors, std::vector<lite::Tensor *> *out_tensors); | |||
| std::vector<lite::Tensor *> *in_tensors, std::vector<lite::Tensor *> *out_tensors, | |||
| TypeId prefer_data_type = kTypeUnknown); | |||
| // find in_kernels_ and out_kernels of kernel, sub_graph and nodes_ in sub_graph | |||
| static void FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels); | |||
| @@ -366,7 +366,8 @@ int CpuFp16SubGraph::PostProcess() { | |||
| } | |||
| } | |||
| } | |||
| if (real_tensor->data_type() == kNumberTypeFloat16 && origin_input_data_.at(real_tensor) != nullptr) { | |||
| if (real_tensor->data_type() == kNumberTypeFloat16 && | |||
| origin_input_data_.find(real_tensor) != origin_input_data_.end()) { | |||
| auto origin_tensor_data = origin_input_data_.at(real_tensor); | |||
| real_tensor->FreeData(); | |||
| MS_ASSERT(origin_tensor_data->data_ != nullptr); | |||
| @@ -0,0 +1 @@ | |||
| decoder_step_201217_modified.pb 5 | |||
| @@ -214,6 +214,21 @@ function Run_Converter() { | |||
| fi | |||
| done < ${models_tflite_fp16_config} | |||
| while read line; do | |||
| fp16_line_info=${line} | |||
| if [[ $fp16_line_info == \#* ]]; then | |||
| continue | |||
| fi | |||
| model_name=`echo ${fp16_line_info}|awk -F ' ' '{print $1}'` | |||
| echo 'cp '${ms_models_path}'/'${model_name}'.ms' ${ms_models_path}'/'${model_name}'.fp16.ms' | |||
| cp ${ms_models_path}/${model_name}.ms ${ms_models_path}/${model_name}.fp16.ms | |||
| if [ $? = 0 ]; then | |||
| converter_result='converter fp16 '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} | |||
| else | |||
| converter_result='converter fp16 '${model_name}' failed';echo ${converter_result} >> ${run_converter_result_file};return 1 | |||
| fi | |||
| done < ${models_tf_fp16_config} | |||
| # Convert tflite weightquant models: | |||
| while read line; do | |||
| weight_quant_line_info=${line} | |||
| @@ -1832,6 +1847,34 @@ function Run_arm64_fp16() { | |||
| run_result='arm64_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 | |||
| fi | |||
| done < ${models_multiple_inputs_fp16_config} | |||
| # Run tf fp16 models | |||
| while read line; do | |||
| model_name_and_input_num=${line%;*} | |||
| length=${#model_name_and_input_num} | |||
| input_shapes=${line:length+1} | |||
| tf_line_info=${model_name_and_input_num} | |||
| if [[ $model_name == \#* ]]; then | |||
| continue | |||
| fi | |||
| model_name=`echo ${tf_line_info}|awk -F ' ' '{print $1}'` | |||
| input_num=`echo ${tf_line_info}|awk -F ' ' '{print $2}'` | |||
| input_files='' | |||
| for i in $(seq 1 $input_num) | |||
| do | |||
| input_files=$input_files'/data/local/tmp/input_output/input/'$model_name'.ms_'$i'.bin,' | |||
| done | |||
| echo ${model_name} >> "${run_arm64_fp16_log_file}" | |||
| echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --inputShapes='${input_shapes}' --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true' >> "${run_arm64_fp16_log_file}" | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --inputShapes='${input_shapes}' --modelFile='${model_name}'.ms --inDataFile='${input_files}' --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true' >> adb_run_cmd.txt | |||
| adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_fp16_log_file}" | |||
| if [ $? = 0 ]; then | |||
| run_result='arm64_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} | |||
| else | |||
| run_result='arm64_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 | |||
| fi | |||
| done < ${models_tf_fp16_config} | |||
| } | |||
| # Run on gpu platform: | |||
| function Run_gpu() { | |||
| @@ -2069,6 +2112,7 @@ models_npu_config=${basepath}/models_npu.cfg | |||
| models_compatibility_config=${basepath}/models_compatibility.cfg | |||
| models_with_multiple_inputs_config=${basepath}/models_with_multiple_inputs.cfg | |||
| models_for_process_only_config=${basepath}/models_for_process_only.cfg | |||
| models_tf_fp16_config=${basepath}/models_tf_fp16.cfg | |||
| ms_models_path=${basepath}/ms_models | |||