diff --git a/mindspore/lite/nnacl/fp16/softmax_fp16.c b/mindspore/lite/nnacl/fp16/softmax_fp16.c index b0df45db6b..257ecf255b 100644 --- a/mindspore/lite/nnacl/fp16/softmax_fp16.c +++ b/mindspore/lite/nnacl/fp16/softmax_fp16.c @@ -20,39 +20,35 @@ // output = exp(input) / reduce_sum(exp(input), axis) void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter) { - int32_t axis = parameter->axis_; + int axis = parameter->axis_; int n_dim = parameter->n_dim_; - int ele_size = parameter->element_size_; int *input_shape = parameter->input_shape_; + int inner_size = 1; + int outter_size = 1; - float16_t max_data = input_ptr[0]; - for (int i = 0; i < ele_size; i++) { - max_data = max_data > input_ptr[i] ? max_data : input_ptr[i]; - } - - for (int i = 0; i < ele_size; i++) { - output_ptr[i] = exp(input_ptr[i] - max_data); - } - int inner_size = 1, outter_size = 1; for (int i = 0; i < axis; i++) { outter_size *= input_shape[i]; } for (int i = axis + 1; i < n_dim; i++) { inner_size *= input_shape[i]; } - for (int i = 0; i < outter_size; i++) { int outter_offset = i * input_shape[axis] * inner_size; int sum_outter_offset = i * inner_size; for (int k = 0; k < inner_size; k++) { int inner_offset = outter_offset + k; + float16_t max_data = input_ptr[inner_offset]; for (int j = 0; j < input_shape[axis]; j++) { int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = exp(input_ptr[axis_offset] - max_data); sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; } } } - for (int i = 0; i < outter_size; i++) { int outter_offset = i * input_shape[axis] * inner_size; int sum_outter_offset = i * inner_size; diff --git a/mindspore/lite/nnacl/fp32/softmax.c b/mindspore/lite/nnacl/fp32/softmax.c index 2f79552399..484f9777bf 100644 --- a/mindspore/lite/nnacl/fp32/softmax.c +++ b/mindspore/lite/nnacl/fp32/softmax.c @@ -20,39 +20,35 @@ // output = exp(input) / reduce_sum(exp(input), axis) void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter) { - int32_t axis = parameter->axis_; + int axis = parameter->axis_; int n_dim = parameter->n_dim_; - int ele_size = parameter->element_size_; int *input_shape = parameter->input_shape_; + int inner_size = 1; + int outter_size = 1; - float max_data = -FLT_MAX; - for (int i = 0; i < ele_size; i++) { - max_data = max_data > input_ptr[i] ? max_data : input_ptr[i]; - } - - for (int i = 0; i < ele_size; i++) { - output_ptr[i] = exp(input_ptr[i] - max_data); - } - int inner_size = 1, outter_size = 1; for (int i = 0; i < axis; i++) { outter_size *= input_shape[i]; } for (int i = axis + 1; i < n_dim; i++) { inner_size *= input_shape[i]; } - for (int i = 0; i < outter_size; i++) { int outter_offset = i * input_shape[axis] * inner_size; int sum_outter_offset = i * inner_size; for (int k = 0; k < inner_size; k++) { int inner_offset = outter_offset + k; + float max_data = input_ptr[inner_offset]; for (int j = 0; j < input_shape[axis]; j++) { int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = exp(input_ptr[axis_offset] - max_data); sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; } } } - for (int i = 0; i < outter_size; i++) { int outter_offset = i * input_shape[axis] * inner_size; int sum_outter_offset = i * inner_size;