@@ -99,6 +99,8 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
                             {28.0f, 28.0f, 28.0f, 28.0f},
                             {3150.0f, 3150.0f, 3150.0f, 3150.0f},
                             {62370.0f, 62370.0f, 62370.0f, 62370.0f}};
+  float32x4_t neg_one = {-1.0f, -1.0f, -1.0f, -1.0f};
+  float32x4_t pos_one = {1.0f, 1.0f, 1.0f, 1.0f};
   int count = (ele_num / C4NUM) * C4NUM;
   for (; i < count; i += C4NUM) {
     float32x4_t input = vcvt_f32_f16(vld1_f16(src + i));
@@ -109,7 +111,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
     float32x4_t b = vaddq_f32(
       vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square),
       paramv[2]);
-    vst1_f16(dst + i, vcvt_f16_f32(vdivq_f32(a, b)));
+    vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one)));
   }
 #endif
   for (; i < ele_num; ++i) {
@@ -118,6 +120,8 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
     float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * input;
     float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f;
     dst[i] = a / b;
+    dst[i] = MSMAX(dst[i], -1);
+    dst[i] = MSMIN(dst[i], 1);
   }
   return NNACL_OK;
 }
@@ -109,6 +109,8 @@ int Tanh(const float *src, int length, float *dst) {
                             {28.0f, 28.0f, 28.0f, 28.0f},
                             {3150.0f, 3150.0f, 3150.0f, 3150.0f},
                             {62370.0f, 62370.0f, 62370.0f, 62370.0f}};
+  float32x4_t neg_one = {-1.0f, -1.0f, -1.0f, -1.0f};
+  float32x4_t pos_one = {1.0f, 1.0f, 1.0f, 1.0f};
   int count = (length / C4NUM) * C4NUM;
   for (; i < count; i += C4NUM) {
     float32x4_t input = vld1q_f32(src + i);
@@ -119,7 +121,7 @@ int Tanh(const float *src, int length, float *dst) {
     float32x4_t b = vaddq_f32(
       vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square),
       paramv[2]);
-    vst1q_f32(dst + i, vdivq_f32(a, b));
+    vst1q_f32(dst + i, vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one));
   }
 #endif
   for (; i < length; ++i) {
@@ -128,6 +130,8 @@ int Tanh(const float *src, int length, float *dst) {
     float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * input;
     float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f;
     dst[i] = a / b;
+    dst[i] = MSMAX(dst[i], -1);
+    dst[i] = MSMIN(dst[i], 1);
   }
   return NNACL_OK;
 }
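
Note: both Tanh paths above evaluate the same odd rational approximation of tanh(x); for large |x| the ratio grows roughly like x/28 instead of saturating, which is why the patch clamps the result to [-1, 1]. Below is a minimal standalone C sketch, my own illustration rather than part of the patch, showing the scalar formula and the effect of the clamp.

/* Standalone illustration (not from the patch): the scalar Tanh path above
 * uses this rational approximation; the new MSMAX/MSMIN lines clamp its
 * output because the ratio keeps growing for large |x|. */
#include <math.h>
#include <stdio.h>

static float TanhRational(float x) {
  float s = x * x;
  float a = (((s + 378.0f) * s + 17325.0f) * s + 135135.0f) * x;
  float b = ((28.0f * s + 3150.0f) * s + 62370.0f) * s + 135135.0f;
  return a / b; /* exceeds 1.0f once |x| is large enough */
}

int main(void) {
  const float xs[] = {0.5f, 2.0f, 6.0f, 12.0f};
  for (int i = 0; i < 4; ++i) {
    float raw = TanhRational(xs[i]);
    float clamped = fminf(fmaxf(raw, -1.0f), 1.0f); /* same effect as the patch */
    printf("x=%5.1f raw=%+.6f clamped=%+.6f tanhf=%+.6f\n", xs[i], raw, clamped, tanhf(xs[i]));
  }
  return 0;
}

Built with "cc tanh_sketch.c -lm" (hypothetical file name), the raw ratio tracks tanhf closely for small inputs and drifts above 1 at x = 6 and x = 12, where the clamp takes over.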
@@ -81,6 +81,43 @@ int ReduceSum(int outer_size, int inner_size, int axis_size, const float *src_da
   }
   return NNACL_OK;
 }
+int IntReduceSum(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
+                 int thread_num) {
+  if (src_data == NULL || dst_data == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  int i, j;
+#ifdef ENABLE_NEON
+  int block_mod = inner_size % C4NUM;
+  int block_c4 = inner_size - block_mod;
+#endif
+  for (j = tid; j < outer_size; j += thread_num) {
+    const int *outer_src = src_data + j * axis_size * inner_size;
+    int *outer_dst = dst_data + j * inner_size;
+    int k = 0;
+#ifdef ENABLE_NEON
+    for (; k < block_c4; k += C4NUM) {
+      const int *inner_src = outer_src + k;
+      int *inner_dst = outer_dst + k;
+      int32x4_t tmp = {0, 0, 0, 0};
+      for (i = 0; i < axis_size; i++) {
+        tmp = vaddq_s32(tmp, vld1q_s32(inner_src + i * inner_size));
+      }
+      vst1q_s32(inner_dst, tmp);
+    }
+#endif
+    for (; k < inner_size; k++) {
+      const int *inner_src = outer_src + k;
+      int *inner_dst = outer_dst + k;
+      int tmp = 0;
+      for (i = 0; i < axis_size; i++) {
+        tmp += inner_src[i * inner_size];
+      }
+      *inner_dst = tmp;
+    }
+  }
+  return NNACL_OK;
+}
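
For reference, the index arithmetic above treats src_data as a contiguous [outer_size][axis_size][inner_size] block and sums away the middle axis; each thread handles every thread_num-th outer slice, and the NEON path processes C4NUM (4) inner elements per iteration. A tiny scalar sketch of that layout follows; it is my own illustration, and the sizes and names are arbitrary.

/* Standalone sketch (not from the patch) of the layout IntReduceSum assumes:
 * dst[j][k] = sum over i of src[j][i][k], i.e. the axis dimension is reduced. */
#include <stdio.h>

int main(void) {
  enum { OUTER = 2, AXIS = 3, INNER = 4 };
  int src[OUTER * AXIS * INNER];
  int dst[OUTER * INNER];
  for (int n = 0; n < OUTER * AXIS * INNER; ++n) {
    src[n] = n; /* 0..23, so the per-column sums are easy to check by hand */
  }
  for (int j = 0; j < OUTER; ++j) {       /* outer_size loop */
    const int *outer_src = src + j * AXIS * INNER;
    for (int k = 0; k < INNER; ++k) {     /* inner_size loop */
      int tmp = 0;
      for (int i = 0; i < AXIS; ++i) {    /* axis_size loop */
        tmp += outer_src[k + i * INNER];
      }
      dst[j * INNER + k] = tmp;
    }
  }
  for (int j = 0; j < OUTER; ++j) {
    for (int k = 0; k < INNER; ++k) {
      printf("%d ", dst[j * INNER + k]);  /* prints 12 15 18 21 / 48 51 54 57 */
    }
    printf("\n");
  }
  return 0;
}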
 int ReduceMax(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
               int thread_num) {
   if (src_data == NULL || dst_data == NULL) {
@@ -26,6 +26,8 @@ int ReduceMean(int outer_size, int inner_size, int axis_size, const float *src_d
               int thread_num);
 int ReduceSum(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
               int thread_num);
+int IntReduceSum(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
+                 int thread_num);
 int ReduceMax(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
               int thread_num);
 int IntReduceMax(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
@@ -50,6 +50,7 @@ int ReduceCPUKernel::Init() {
   switch (mode_) {
     case static_cast<int>(ReduceMode_ReduceSum): {
       reducer_ = ReduceSum;
+      int_reducer_ = IntReduceSum;
       break;
     }
     case static_cast<int>(ReduceMode_ReduceMean): {
@@ -23,10 +23,6 @@ using mindspore::schema::PrimitiveType_Conv2D;
 namespace mindspore::kernel {
 int ConvolutionNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs,
                                     const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) {
-  if (conv_param_->group_ != 1) {
-    MS_LOG(WARNING) << "Only support group equals 1 for npu convolution op";
-    return RET_ERROR;
-  }
   return RET_OK;
 }
@@ -56,10 +56,11 @@ int PoolingNPUKernel::SetPoolingParam() {
   if (pooling_param_->round_mode_ == RoundMode_Floor) {  // no use in cpu
     pooling_->set_attr_ceil_mode(0);
+    pooling_->set_attr_data_mode(1);
   } else {
     pooling_->set_attr_ceil_mode(1);
+    pooling_->set_attr_data_mode(0);
   }
-  // todo data mode
   return RET_OK;
 }
@@ -73,3 +73,4 @@ ml_video_edit_video_segment_gauss_adaptis_part1
 ml_video_edit_Mnet
 ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145
 ml_video_edit_person_divison_video
+hdc_Face_Aesthetic_MTI_Aesthetic
@@ -41,3 +41,12 @@ ml_video_edit_img_segment_adaptise.pb 0.5 2
 ml_video_edit_video_segment_gauss_adaptis_part2.pb 3 2
 ml_video_edit_person_divison_pic 8 2
 ml_video_edit_person_divison_video 0.5
+ml_video_edit_imitate_filter.onnx 230
+ml_video_edit_judge.onnx 5
+ml_video_edit_vignet.onnx 0.5
+hdc_Face_Aesthetic_MTI_Aesthetic 0.5
+hdc_Face_Emotion_MTI_Aesthetic.onnx 30
+hdc_Face_Landmark5_MTI_Aesthetic.onnx 0.5
+hdc_Image_Aesthetic_MTI_Aesthetic.onnx 0.5
+hdc_mobilenet_1w_class.onnx 10
+hdc_resnet_1w_class.onnx 5
@@ -45,3 +45,9 @@ ml_video_edit_style_transfer_starry.onnx
 ml_video_edit_judge.onnx
 ml_video_edit_vignet.onnx
 ssd_mobilenet_v1_10.onnx;1,383,640,3
+hdc_Face_Emotion_MTI_Aesthetic.onnx
+hdc_Face_Landmark5_MTI_Aesthetic.onnx
+hdc_Image_Aesthetic_MTI_Aesthetic.onnx
+hdc_mobilenet_1w_class.onnx
+hdc_resnet_1w_class.onnx
+ml_video_edit_imitate_filter.onnx
@@ -26,7 +26,7 @@ crnn_lite_lstm_v2.onnx;32,32,32,1 0.3
 psenet_lite_mbv2.onnx;1,32,32,3 0.6
 super-resolution-10.onnx;1,224,224,1 4.5
 tinyyolov2-8.onnx;1,416,416,3 5.5
-ml_2012_ocr_cn.onnx 200
+ml_2012_ocr_cn.onnx -1
 #ml_2012_ocr_cn_noLSTM.onnx 1
 candy-9.onnx 5
 mosaic-9.onnx 4
@@ -10,3 +10,4 @@ ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2
 decoder.onnx;2;1,7,512:1,7
 fasterrcnn_crop.pb;1;420,630,3
 ml_video_edit_person_divison_pic;2
+hdc_tb_cn_neg.tflite;3
@@ -1547,8 +1547,11 @@ function Run_arm64() {
         echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
         echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt
-        echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} ' --inputShapes='${input_shapes} >> adb_run_cmd.txt
+        if [[ $accuracy_limit == "-1" ]]; then
+            echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --inputShapes='${input_shapes} >> adb_run_cmd.txt
+        else
+            echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} ' --inputShapes='${input_shapes} >> adb_run_cmd.txt
+        fi
         cat adb_run_cmd.txt >> "${run_arm64_log_file}"
         adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
         if [ $? = 0 ]; then
@@ -19,8 +19,8 @@
 namespace mindspore {
 namespace lite {
-STATUS CaffeConvolutionParser::ParseGroupConvolution(schema::PrimitiveT *primitiveT, schema::Conv2DT *attr) {
-  if (attr->group == 1) {
+STATUS CaffeConvolutionParser::ParseDepthwiseConvolution(schema::PrimitiveT *primitiveT, schema::Conv2DT *attr) {
+  if (attr->group == 1 || attr->group != attr->channelOut) {
     return RET_OK;
   }
   std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>();
@@ -125,9 +125,9 @@ PrimitiveC *CaffeConvolutionParser::ParseLitePrimitive(const caffe::LayerParamet
   primitive->value.type = schema::PrimitiveType_Conv2D;
   primitive->value.value = attr.release();
-  status = ParseGroupConvolution(primitive.get(), static_cast<schema::Conv2DT *>(primitive->value.value));
+  status = ParseDepthwiseConvolution(primitive.get(), static_cast<schema::Conv2DT *>(primitive->value.value));
   if (status != RET_OK) {
-    MS_LOG(ERROR) << "Parse group convolution failed";
+    MS_LOG(ERROR) << "Parse depthwise convolution failed";
     return nullptr;
   }
@@ -32,7 +32,7 @@ class CaffeConvolutionParser : public CaffeNodeParser {
   PrimitiveC *ParseLitePrimitive(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight) override;

 private:
-  static STATUS ParseGroupConvolution(schema::PrimitiveT *primitiveT, schema::Conv2DT *attr);
+  static STATUS ParseDepthwiseConvolution(schema::PrimitiveT *primitiveT, schema::Conv2DT *attr);
 };
 } // namespace lite
 } // namespace mindspore