Merge pull request !4062 from yangruoqi713/fix_bugtags/v0.7.0-beta
| @@ -376,7 +376,7 @@ table BNGradInput { | |||
| channels: int; | |||
| } | |||
| table Scale { | |||
| format: Format = 0; | |||
| axis: int; | |||
| } | |||
| table Eltwise { | |||
| @@ -28,12 +28,16 @@ int Nchw2Nhwc::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| std::vector<int> nchw_shape = input->shape(); | |||
| std::vector<int> nhwc_shape{nchw_shape}; | |||
| nhwc_shape[NHWC_N] = nchw_shape[NCHW_N]; | |||
| nhwc_shape[NHWC_H] = nchw_shape[NCHW_H]; | |||
| nhwc_shape[NHWC_W] = nchw_shape[NCHW_W]; | |||
| nhwc_shape[NHWC_C] = nchw_shape[NCHW_C]; | |||
| output->set_shape(nhwc_shape); | |||
| if (nchw_shape.size() != 4) { | |||
| output->set_shape(nchw_shape); | |||
| } else { | |||
| std::vector<int> nhwc_shape{nchw_shape}; | |||
| nhwc_shape[NHWC_N] = nchw_shape[NCHW_N]; | |||
| nhwc_shape[NHWC_H] = nchw_shape[NCHW_H]; | |||
| nhwc_shape[NHWC_W] = nchw_shape[NCHW_W]; | |||
| nhwc_shape[NHWC_C] = nchw_shape[NCHW_C]; | |||
| output->set_shape(nhwc_shape); | |||
| } | |||
| output->SetFormat(schema::Format_NHWC); | |||
| output->set_data_type(input->data_type()); | |||
| return RET_OK; | |||
| @@ -28,15 +28,18 @@ int Nhwc2Nchw::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| std::vector<int> nhwc_shape = input->shape(); | |||
| std::vector<int> nchw_shape{nhwc_shape}; | |||
| nchw_shape[NCHW_N] = nhwc_shape[NHWC_N]; | |||
| nchw_shape[NCHW_C] = nhwc_shape[NHWC_C]; | |||
| nchw_shape[NCHW_H] = nhwc_shape[NHWC_H]; | |||
| nchw_shape[NCHW_W] = nhwc_shape[NHWC_W]; | |||
| output->set_shape(nchw_shape); | |||
| if (nhwc_shape.size() != 4) { | |||
| output->set_shape(nhwc_shape); | |||
| } else { | |||
| std::vector<int> nchw_shape{nhwc_shape}; | |||
| nchw_shape[NCHW_N] = nhwc_shape[NHWC_N]; | |||
| nchw_shape[NCHW_C] = nhwc_shape[NHWC_C]; | |||
| nchw_shape[NCHW_H] = nhwc_shape[NHWC_H]; | |||
| nchw_shape[NCHW_W] = nhwc_shape[NHWC_W]; | |||
| output->set_shape(nchw_shape); | |||
| } | |||
| output->SetFormat(schema::Format_NCHW); | |||
| output->set_data_type(input->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -753,15 +753,7 @@ OpParameter *PopulateScaleParameter(const lite::Primitive *primitive) { | |||
| MS_LOG(ERROR) << "value_as_Scale return nullptr"; | |||
| return nullptr; | |||
| } | |||
| // NCHW todo use enum | |||
| if (param->format() == schema::Format_NCHW) { | |||
| scale_param->axis_ = 1; | |||
| scale_param->num_axis_ = 1; | |||
| } else if (param->format() == schema::Format_NHWC) { | |||
| scale_param->axis_ = 3; | |||
| scale_param->num_axis_ = 1; | |||
| } | |||
| scale_param->axis_ = param->axis(); | |||
| return reinterpret_cast<OpParameter *>(scale_param); | |||
| } | |||
| @@ -278,7 +278,7 @@ int Convolution3x3FP16CPUKernel::Run() { | |||
| auto out_tensor = outputs_.at(kOutputIndex); | |||
| auto output_addr = reinterpret_cast<float *>(out_tensor->Data()); | |||
| for (int j = 0; j < out_tensor->ElementsNum(); ++j) { | |||
| output_addr[j] = (float)fp16_out_[j]; | |||
| output_addr[j] = (reinterpret_cast<float *>(fp16_out_))[j]; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -29,85 +29,91 @@ using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Scale; | |||
| namespace mindspore::kernel { | |||
| namespace { | |||
| constexpr int kScaleInputNum = 1; | |||
| constexpr int kScaleOutputNum = 1; | |||
| } // namespace | |||
| int ScaleCPUKernel::Init() { | |||
| int ScaleCPUKernel::InitScaleOffset() { | |||
| auto param = reinterpret_cast<ScaleParameter *>(opParameter); | |||
| auto in_tensor = inputs_.front(); | |||
| auto scale = inputs_.at(1); | |||
| if (inputs_.size() < 2 || inputs_.size() > 3) { | |||
| MS_LOG(ERROR) << "inputs to Scale operator should be 2 or 3, but " << inputs_.size() << " is given."; | |||
| return RET_ERROR; | |||
| auto scale_tensor = inputs_.at(1); | |||
| float *scale_ptr = reinterpret_cast<float *>(inputs_.at(1)->Data()); | |||
| if (scale_ptr != nullptr) { | |||
| scale_ = reinterpret_cast<float *>(malloc(scale_tensor->ElementsNum() * sizeof(float))); | |||
| if (scale_ == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc buffer failed."; | |||
| return RET_ERROR; | |||
| } | |||
| memcpy(scale_, scale_ptr, scale_tensor->ElementsNum() * sizeof(float)); | |||
| } else { | |||
| scale_ = nullptr; | |||
| } | |||
| if (param->axis_ < 0) { | |||
| MS_LOG(ERROR) << "axis illegal."; | |||
| return RET_ERROR; | |||
| if (inputs_.size() == 3) { | |||
| auto offset_tensor = inputs_.at(1); | |||
| offset_ = reinterpret_cast<float *>(malloc(offset_tensor->ElementsNum() * sizeof(float))); | |||
| if (offset_ == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc buffer failed."; | |||
| return RET_ERROR; | |||
| } | |||
| param->has_offset_ = true; | |||
| } else { | |||
| offset_ = nullptr; | |||
| param->has_offset_ = false; | |||
| } | |||
| if (param->num_axis_ < 1 || param->num_axis_ + param->axis_ >= in_tensor->shape().size()) { | |||
| MS_LOG(ERROR) << "number of axis illegal"; | |||
| return RET_OK; | |||
| } | |||
| int ScaleCPUKernel::InitParameter() { | |||
| auto param = reinterpret_cast<ScaleParameter *>(opParameter); | |||
| auto in_tensor = inputs_.at(0); | |||
| auto in_shape = in_tensor->shape(); | |||
| auto scale_tensor = inputs_.at(1); | |||
| auto scale_shape = scale_tensor->shape(); | |||
| if (scale_shape.size() + param->axis_ > in_shape.size()) { | |||
| MS_LOG(ERROR) << "Scale tensor shape is incorrect."; | |||
| return RET_ERROR; | |||
| } | |||
| param->channel_ = 1; | |||
| param->out_count_ = 1; | |||
| param->in_stride_ = 1; | |||
| int cur_axis; | |||
| for (cur_axis = 0; cur_axis < param->axis_; cur_axis++) { | |||
| param->out_count_ *= in_tensor->shape()[cur_axis]; | |||
| param->outer_size_ = 1; | |||
| param->axis_size_ = 1; | |||
| param->inner_size_ = 1; | |||
| for (int i = 0; i < param->axis_; i++) { | |||
| param->outer_size_ *= in_shape[i]; | |||
| } | |||
| for (int i = 0; i < param->num_axis_; i++) { | |||
| param->channel_ *= in_tensor->shape()[(cur_axis++)]; | |||
| for (int i = 0; i < scale_shape.size(); i++) { | |||
| if (in_shape[i + param->axis_] != scale_shape[i]) { | |||
| MS_LOG(ERROR) << "Scale tensor shape is incorrect."; | |||
| return RET_ERROR; | |||
| } | |||
| param->axis_size_ *= in_shape[i + param->axis_]; | |||
| } | |||
| for (int i = cur_axis; i < in_tensor->shape().size(); i++) { | |||
| param->in_stride_ *= in_tensor->shape()[cur_axis]; | |||
| for (int i = param->axis_ + scale_shape.size(); i < in_shape.size(); i++) { | |||
| param->inner_size_ *= in_shape[i]; | |||
| } | |||
| if (scale->shape().back() != param->channel_ || scale->shape().size() > 2) { | |||
| MS_LOG(ERROR) << "scale shape illegal."; | |||
| return RET_OK; | |||
| } | |||
| int ScaleCPUKernel::Init() { | |||
| if (inputs_.size() < 2 || inputs_.size() > 3) { | |||
| MS_LOG(ERROR) << "inputs to Scale operator should be 2 or 3, but " << inputs_.size() << " is given."; | |||
| return RET_ERROR; | |||
| } | |||
| if (inputs_.size() == 3) { | |||
| if ((inputs_.at(2))->shape().back() != param->channel_ || (inputs_.at(2))->shape().size() > 2) { | |||
| MS_LOG(ERROR) << "offset shape illegal."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| input_ptr_ = reinterpret_cast<float *>(inputs_.front()->Data()); | |||
| scale_ = reinterpret_cast<float *>(inputs_.at(1)->Data()); | |||
| if (inputs_.size() == 3) { | |||
| offset_ = reinterpret_cast<float *>(inputs_.at(2)->Data()); | |||
| has_offset_ = true; | |||
| } else { | |||
| offset_ = nullptr; | |||
| has_offset_ = false; | |||
| auto ret = InitParameter(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Scale fp32 InitParameter failed."; | |||
| return RET_ERROR; | |||
| } | |||
| output_ptr_ = reinterpret_cast<float *>(outputs_.front()->Data()); | |||
| num_unit_ = param->out_count_ * param->channel_; | |||
| unit_size_ = param->in_stride_; | |||
| thread_n_num_ = MSMIN(thread_num_, num_unit_); | |||
| thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); | |||
| ret = InitScaleOffset(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Scale fp32 InitScaleOffset failed."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ScaleCPUKernel::ReSize() { return RET_OK; } | |||
| int ScaleCPUKernel::Scale(int task_id) { | |||
| int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); | |||
| if (num_unit_thread <= 0) { | |||
| return RET_OK; | |||
| } | |||
| int thread_offset = task_id * thread_n_stride_; | |||
| int ret; | |||
| if (has_offset_) { | |||
| ret = DoScale(input_ptr_, output_ptr_, scale_, offset_, thread_offset, num_unit_thread, | |||
| reinterpret_cast<ScaleParameter *>(opParameter)); | |||
| } else { | |||
| ret = DoScale(input_ptr_, output_ptr_, scale_, thread_offset, num_unit_thread, | |||
| reinterpret_cast<ScaleParameter *>(opParameter)); | |||
| } | |||
| auto ret = | |||
| DoScale(input_ptr_, output_ptr_, scale_, offset_, task_id, reinterpret_cast<ScaleParameter *>(opParameter)); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Scale error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| @@ -116,11 +122,9 @@ int ScaleCPUKernel::Scale(int task_id) { | |||
| return RET_OK; | |||
| } | |||
| int ScaleCPUKernel::ReSize() { return RET_OK; } | |||
| int ScaleRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||
| auto g_kernel = reinterpret_cast<ScaleCPUKernel *>(cdata); | |||
| auto ret = g_kernel->Scale(task_id); | |||
| auto scale = reinterpret_cast<ScaleCPUKernel *>(cdata); | |||
| auto ret = scale->Scale(task_id); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ScaleRun error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| return RET_ERROR; | |||
| @@ -129,7 +133,16 @@ int ScaleRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||
| } | |||
| int ScaleCPUKernel::Run() { | |||
| int ret = LiteBackendParallelLaunch(ScaleRun, this, thread_n_num_); | |||
| auto in_tensor = inputs_.front(); | |||
| input_ptr_ = reinterpret_cast<float *>(in_tensor->Data()); | |||
| if (scale_ == nullptr) { | |||
| auto scale_tensor = inputs_[1]; | |||
| scale_ = reinterpret_cast<float *>(scale_tensor->Data()); | |||
| } | |||
| auto out_tensor = outputs_.front(); | |||
| output_ptr_ = reinterpret_cast<float *>(out_tensor->Data()); | |||
| int ret = LiteBackendParallelLaunch(ScaleRun, this, opParameter->thread_num_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; | |||
| return RET_ERROR; | |||
| @@ -160,7 +173,6 @@ kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector<lite::tensor::Te | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -26,27 +26,24 @@ class ScaleCPUKernel : public LiteKernel { | |||
| public: | |||
| explicit ScaleCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx) | |||
| : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->thread_num_) {} | |||
| : LiteKernel(parameter, inputs, outputs) { | |||
| opParameter->thread_num_ = ctx->thread_num_; | |||
| } | |||
| ~ScaleCPUKernel() override = default; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int InitParameter(); | |||
| int InitScaleOffset(); | |||
| int Scale(int task_id); | |||
| private: | |||
| int thread_num_; | |||
| int thread_n_stride_; | |||
| int thread_n_num_; | |||
| int num_unit_; | |||
| int unit_size_; | |||
| float *input_ptr_; | |||
| float *scale_; | |||
| float *offset_; | |||
| float *output_ptr_; | |||
| bool has_offset_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCALE_H_ | |||
| @@ -17,37 +17,33 @@ | |||
| #include "src/runtime/kernel/arm/opclib/scale.h" | |||
| #include "src/runtime/kernel/arm/opclib/errorcode.h" | |||
| int DoScale(float *in_data, float *out_data, float *scale, float *offset, int units_offset, int num_unit, | |||
| ScaleParameter *scale_param) { | |||
| int DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param) { | |||
| if (in_data == nullptr || out_data == nullptr || scale == nullptr || offset == nullptr || scale_param == nullptr) { | |||
| return OPCLIB_ERR; | |||
| } | |||
| int in_stride_j = units_offset * scale_param->in_stride_; | |||
| for (int j = units_offset; j < units_offset + num_unit; j++) { | |||
| int channel = j % scale_param->channel_; | |||
| for (int k = 0; k < scale_param->in_stride_; k++) { | |||
| out_data[in_stride_j + k] = in_data[in_stride_j + k] * scale[channel] + offset[channel]; | |||
| if (scale_param->has_offset_) { | |||
| for (int out = task_id; out < scale_param->outer_size_; out += scale_param->op_parameter_.thread_num_) { | |||
| int out_offset = out * scale_param->axis_size_ * scale_param->inner_size_; | |||
| for (int i = 0; i < scale_param->axis_size_; i++) { | |||
| int axis_offset = out_offset + i * scale_param->inner_size_; | |||
| for (int in = 0; in < scale_param->inner_size_; in++) { | |||
| int in_offset = axis_offset + in; | |||
| out_data[in_offset] = in_data[in_offset] * scale[i] + offset[i]; | |||
| } | |||
| } | |||
| } | |||
| in_stride_j = in_stride_j + scale_param->in_stride_; | |||
| } | |||
| return OPCLIB_OK; | |||
| } | |||
| int DoScale(float *in_data, float *out_data, float *scale, int units_offset, int num_unit, | |||
| ScaleParameter *scale_param) { | |||
| if (in_data == nullptr || out_data == nullptr || scale == nullptr || scale_param == nullptr) { | |||
| return OPCLIB_ERR; | |||
| } | |||
| int in_stride_j = units_offset * scale_param->in_stride_; | |||
| for (int j = units_offset; j < units_offset + num_unit; j++) { | |||
| int channel = j % scale_param->channel_; | |||
| for (int k = 0; k < scale_param->in_stride_; k++) { | |||
| out_data[in_stride_j + k] = in_data[in_stride_j + k] * scale[channel]; | |||
| } else { | |||
| for (int out = task_id; out < scale_param->outer_size_; out += scale_param->op_parameter_.thread_num_) { | |||
| int out_offset = out * scale_param->axis_size_ * scale_param->inner_size_; | |||
| for (int i = 0; i < scale_param->axis_size_; i++) { | |||
| int axis_offset = out_offset + i * scale_param->inner_size_; | |||
| for (int in = 0; in < scale_param->inner_size_; in++) { | |||
| int in_offset = axis_offset + in; | |||
| out_data[in_offset] = in_data[in_offset] * scale[i]; | |||
| } | |||
| } | |||
| } | |||
| in_stride_j = in_stride_j + scale_param->in_stride_; | |||
| } | |||
| return OPCLIB_OK; | |||
| } | |||
| @@ -21,15 +21,13 @@ | |||
| struct ScaleParameter { | |||
| OpParameter op_parameter_; | |||
| int out_count_; | |||
| int channel_; | |||
| int in_stride_; | |||
| int outer_size_; | |||
| int axis_size_; | |||
| int inner_size_; | |||
| int axis_; | |||
| int num_axis_; | |||
| bool has_offset_; | |||
| // todo yangruoqi: axis | |||
| }; | |||
| int DoScale(float *in_data, float *out_data, float *scale, float *offset, int units_offset, int num_unit, | |||
| ScaleParameter *scale_param); | |||
| int DoScale(float *in_data, float *out_data, float *scale, int units_offset, int num_unit, ScaleParameter *scale_param); | |||
| int DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param); | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SCALE_H_ | |||
| @@ -22,12 +22,9 @@ const int32_t DIM_DEFAULT_SIZE = 4; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, | |||
| const caffe::LayerParameter &weight, | |||
| schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) { | |||
| STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | |||
| std::unique_ptr<schema::ScaleT> attr(new schema::ScaleT()); | |||
| attr->format = schema::Format_NCHW; | |||
| if (weight.blobs_size() + weight.bottom_size() < 2) { | |||
| // MS_LOGE("Scale bottom size:%d, blobs size:%d invalid in layer %s", weight.bottom_size(), weight.blobs_size(), | |||
| @@ -36,12 +33,14 @@ STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, | |||
| } | |||
| const caffe::ScaleParameter scaleParam = weight.scale_param(); | |||
| int32_t axis = scaleParam.axis(); // NCHW_DIM_C; | |||
| uint32_t axis_index = NCHW_DIM_C; | |||
| if (GetAxisIndex(axis, &axis_index)) { | |||
| // MS_LOGE("scale get axis failed for layer %s.", weight.name().c_str()); | |||
| int axis = NCHW_DIM_C; | |||
| if (scaleParam.has_axis()) { | |||
| uint32_t axis_index = NCHW_DIM_C; | |||
| if (GetAxisIndex(scaleParam.axis(), &axis_index)) { | |||
| // MS_LOGE("scale get axis failed for layer %s.", weight.name().c_str()); | |||
| } | |||
| } | |||
| attr->axis = axis; | |||
| // parse scale | |||
| // todo expect only weight as scale not bias | |||
| @@ -94,4 +93,3 @@ STATUS CaffeScaleParser::GetAxisIndex(const int32_t &axis, uint32_t *axis_index) | |||
| CaffeNodeRegistrar g_caffeScaleParser("Scale", new CaffeScaleParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||