| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/constant_of_shape.h" | |||||
| int ConstantOfShapeInt32(int32_t *output, int start, int end, int32_t value) { | |||||
| for (int i = start; i < end; i++) { | |||||
| output[i] = value; | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| int ConstantOfShapeFp32(float *output, int start, int end, float value) { | |||||
| for (int i = start; i < end; i++) { | |||||
| output[i] = value; | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| @@ -24,17 +24,19 @@ | |||||
| typedef struct ConstantOfShapeParameter { | typedef struct ConstantOfShapeParameter { | ||||
| OpParameter op_parameter_; | OpParameter op_parameter_; | ||||
| float value_; | |||||
| union value_ { | |||||
| float f32_value_; | |||||
| int32_t int32_value_; | |||||
| } value_; | |||||
| int data_type_; | int data_type_; | ||||
| int unit_; | |||||
| int element_sz_; | |||||
| int element_size_; | |||||
| } ConstantOfShapeParameter; | } ConstantOfShapeParameter; | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| int ConstantOfShape(float *output, int tid, const ConstantOfShapeParameter *param); | |||||
| int ConstantOfShapeInt(int32_t *output, int tid, const ConstantOfShapeParameter *param); | |||||
| int ConstantOfShapeFp32(float *output, int start, int end, float value); | |||||
| int ConstantOfShapeInt32(int32_t *output, int start, int end, int32_t value); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -1,39 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/fp32/constant_of_shape_fp32.h" | |||||
| int ConstantOfShape(float *output, int tid, const ConstantOfShapeParameter *param) { | |||||
| int size = param->unit_; | |||||
| float data = param->value_; | |||||
| int ind_st = MSMIN(tid * size, param->element_sz_); | |||||
| int ind_end = MSMIN(param->element_sz_, (tid + 1) * size); | |||||
| for (int i = ind_st; i < ind_end; ++i) { | |||||
| output[i] = data; | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| int ConstantOfShapeInt(int32_t *output, int tid, const ConstantOfShapeParameter *param) { | |||||
| int size = param->unit_; | |||||
| float data = param->value_; | |||||
| int ind_st = MSMIN(tid * size, param->element_sz_); | |||||
| int ind_end = MSMIN(param->element_sz_, (tid + 1) * size); | |||||
| for (int i = ind_st; i < ind_end; ++i) { | |||||
| output[i] = data; | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| @@ -78,25 +78,42 @@ int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso | |||||
| MS_LOG(ERROR) << "outputs to ConstantOfShape operator should be 1, but " << outputs_.size() << " is given."; | MS_LOG(ERROR) << "outputs to ConstantOfShape operator should be 1, but " << outputs_.size() << " is given."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto in_tensor = inputs_.front(); | auto in_tensor = inputs_.front(); | ||||
| auto out_tensor = outputs_.front(); | auto out_tensor = outputs_.front(); | ||||
| out_tensor->set_data_type(static_cast<TypeId>(GetDataType())); | out_tensor->set_data_type(static_cast<TypeId>(GetDataType())); | ||||
| out_tensor->set_format(in_tensor->format()); | out_tensor->set_format(in_tensor->format()); | ||||
| if (!infer_flag()) { | |||||
| return RET_INFER_INVALID; | |||||
| } | |||||
| auto in_data = reinterpret_cast<int *>(in_tensor->data_c()); | |||||
| if (in_data == nullptr) { | |||||
| MS_LOG(INFO) << "Input data is nullptr. Input tensor has not been calculated out yet."; | |||||
| if (!infer_flag() || in_tensor->data_c() == nullptr) { | |||||
| return RET_INFER_INVALID; | return RET_INFER_INVALID; | ||||
| } | } | ||||
| int size = in_tensor->ElementsNum(); | int size = in_tensor->ElementsNum(); | ||||
| std::vector<int> out_shape(size); | std::vector<int> out_shape(size); | ||||
| for (int i = 0; i < size; ++i) { | |||||
| out_shape[i] = in_data[i]; | |||||
| switch (in_tensor->data_type()) { | |||||
| case kNumberTypeInt32: { | |||||
| int32_t *in_data = reinterpret_cast<int32_t *>(in_tensor->data_c()); | |||||
| for (int i = 0; i < size; ++i) { | |||||
| out_shape[i] = in_data[i]; | |||||
| MS_ASSERT(out_shape[i] > 0); | |||||
| } | |||||
| break; | |||||
| } | |||||
| case kNumberTypeInt64: { | |||||
| int64_t *in_data = reinterpret_cast<int64_t *>(in_tensor->data_c()); | |||||
| for (int i = 0; i < size; ++i) { | |||||
| out_shape[i] = in_data[i]; | |||||
| MS_ASSERT(out_shape[i] > 0); | |||||
| } | |||||
| break; | |||||
| } | |||||
| default: | |||||
| MS_LOG(INFO) << "Invalid input data type!"; | |||||
| return RET_INFER_INVALID; | |||||
| } | } | ||||
| out_tensor->set_shape(out_shape); | |||||
| out_tensor->set_shape(out_shape); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include "src/tensor.h" | #include "src/tensor.h" | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/fp32/constant_of_shape_fp32.h" | |||||
| #include "nnacl/constant_of_shape.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| namespace { | namespace { | ||||
| @@ -34,13 +34,22 @@ OpParameter *PopulateConstantOfShapeParameter(const mindspore::lite::PrimitiveC | |||||
| } | } | ||||
| memset(param, 0, sizeof(ConstantOfShapeParameter)); | memset(param, 0, sizeof(ConstantOfShapeParameter)); | ||||
| param->op_parameter_.type_ = primitive->Type(); | param->op_parameter_.type_ = primitive->Type(); | ||||
| param->data_type_ = attr->GetDataType(); | |||||
| auto value = attr->GetValue(); | auto value = attr->GetValue(); | ||||
| if (value.empty() || value.size() > 1) { | if (value.empty() || value.size() > 1) { | ||||
| MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1."; | MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1."; | ||||
| } else { | } else { | ||||
| param->value_ = attr->GetValue().at(0); | |||||
| switch (param->data_type_) { | |||||
| case kNumberTypeFloat32: | |||||
| param->value_.f32_value_ = attr->GetValue().at(0); | |||||
| break; | |||||
| case kNumberTypeInt32: | |||||
| param->value_.int32_value_ = attr->GetValue().at(0); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "The value of constant of shape is invalid"; | |||||
| } | |||||
| } | } | ||||
| param->data_type_ = attr->GetDataType(); | |||||
| return reinterpret_cast<OpParameter *>(param); | return reinterpret_cast<OpParameter *>(param); | ||||
| } | } | ||||
| Registry ConstantOfShapeParameterRegistry(schema::PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter); | Registry ConstantOfShapeParameterRegistry(schema::PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter); | ||||
| @@ -14,11 +14,9 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/fp32/constant_of_shape_fp32.h" | |||||
| #include <vector> | |||||
| #include "src/runtime/kernel/arm/base/constant_of_shape.h" | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -28,30 +26,6 @@ using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_ConstantOfShape; | using mindspore::schema::PrimitiveType_ConstantOfShape; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int ConstantOfShapeCPUKernel::Init() { return RET_OK; } | |||||
| int ConstantOfShapeCPUKernel::ReSize() { return RET_OK; } | |||||
| int ConstantOfShapeCPUKernel::DoExecute(int task_id) { | |||||
| int ret = RET_ERROR; | |||||
| switch (param_->data_type_) { | |||||
| case kNumberTypeFloat32: | |||||
| ret = ConstantOfShape(reinterpret_cast<float *>(out_ptr_), task_id, param_); | |||||
| break; | |||||
| case kNumberTypeInt32: | |||||
| ret = ConstantOfShapeInt(reinterpret_cast<int32_t *>(out_ptr_), task_id, param_); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Constant of shape does not support the output data type."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ConstantOfShapeRun error task_id[" << task_id << "] error_code[" << ret << "]"; | |||||
| return ret; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ConstantOfShapeRun(void *cdata, int task_id) { | int ConstantOfShapeRun(void *cdata, int task_id) { | ||||
| auto g_kernel = reinterpret_cast<ConstantOfShapeCPUKernel *>(cdata); | auto g_kernel = reinterpret_cast<ConstantOfShapeCPUKernel *>(cdata); | ||||
| auto ret = g_kernel->DoExecute(task_id); | auto ret = g_kernel->DoExecute(task_id); | ||||
| @@ -62,23 +36,38 @@ int ConstantOfShapeRun(void *cdata, int task_id) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ConstantOfShapeCPUKernel::Run() { | |||||
| param_->element_sz_ = out_tensors_.front()->ElementsNum(); | |||||
| int thread_num = MSMIN(param_->op_parameter_.thread_num_, param_->element_sz_); | |||||
| param_->unit_ = UP_DIV(param_->element_sz_, thread_num); | |||||
| param_->op_parameter_.thread_num_ = thread_num; | |||||
| int ConstantOfShapeCPUKernel::DoExecute(int task_id) { | |||||
| int start = task_id * thread_stride_; | |||||
| int current_stride = MSMIN(thread_stride_, param_->element_size_ - start); | |||||
| if (current_stride < 0) { | |||||
| return RET_OK; | |||||
| } | |||||
| switch (param_->data_type_) { | switch (param_->data_type_) { | ||||
| case kNumberTypeFloat32: | case kNumberTypeFloat32: | ||||
| out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData()); | |||||
| ConstantOfShapeFp32(reinterpret_cast<float *>(output_ptr_), start, start + current_stride, | |||||
| param_->value_.f32_value_); | |||||
| break; | break; | ||||
| case kNumberTypeInt32: | case kNumberTypeInt32: | ||||
| out_ptr_ = reinterpret_cast<int32_t *>(out_tensors_.front()->MutableData()); | |||||
| ConstantOfShapeInt32(reinterpret_cast<int32_t *>(output_ptr_), start, start + current_stride, | |||||
| param_->value_.int32_value_); | |||||
| break; | break; | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Constant of shape does not support the output data type."; | |||||
| MS_LOG(ERROR) << "Invalid datatype in ConstantOfShapeRun"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, ConstantOfShapeRun, this, thread_num); | |||||
| return RET_OK; | |||||
| } | |||||
| int ConstantOfShapeCPUKernel::Run() { | |||||
| auto output = out_tensors_.front(); | |||||
| param_->data_type_ = output->data_type(); | |||||
| param_->element_size_ = output->ElementsNum(); | |||||
| output_ptr_ = output->data_c(); | |||||
| int thread_count = MSMIN(op_parameter_->thread_num_, param_->element_size_); | |||||
| thread_stride_ = UP_DIV(param_->element_size_, thread_count); | |||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, ConstantOfShapeRun, this, thread_count); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "ConstantOfShapeRun error error_code[" << ret << "]"; | MS_LOG(ERROR) << "ConstantOfShapeRun error error_code[" << ret << "]"; | ||||
| return ret; | return ret; | ||||
| @@ -88,4 +77,5 @@ int ConstantOfShapeCPUKernel::Run() { | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) | ||||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) | REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) | ||||
| REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -13,15 +13,14 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_CONSTANT_OF_SHAPE_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_CONSTANT_OF_SHAPE_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "include/errorcode.h" | |||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| #include "include/context.h" | #include "include/context.h" | ||||
| #include "nnacl/fp32/constant_of_shape_fp32.h" | |||||
| using mindspore::lite::InnerContext; | |||||
| #include "nnacl/constant_of_shape.h" | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| class ConstantOfShapeCPUKernel : public LiteKernel { | class ConstantOfShapeCPUKernel : public LiteKernel { | ||||
| @@ -34,15 +33,16 @@ class ConstantOfShapeCPUKernel : public LiteKernel { | |||||
| } | } | ||||
| ~ConstantOfShapeCPUKernel() override = default; | ~ConstantOfShapeCPUKernel() override = default; | ||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Init() override { return lite::RET_OK; } | |||||
| int ReSize() override { return lite::RET_OK; } | |||||
| int Run() override; | int Run() override; | ||||
| int DoExecute(int task_id); | int DoExecute(int task_id); | ||||
| private: | private: | ||||
| ConstantOfShapeParameter *param_ = nullptr; | ConstantOfShapeParameter *param_ = nullptr; | ||||
| void *out_ptr_ = nullptr; | |||||
| void *output_ptr_ = nullptr; | |||||
| int thread_stride_; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_CONSTANT_OF_SHAPE_H_ | |||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "common/common_test.h" | #include "common/common_test.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape_fp32.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.h" | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -47,7 +47,8 @@ TEST_F(TestConstantOfShapeFp32, Simple) { | |||||
| std::vector<lite::Tensor *> inputs_; | std::vector<lite::Tensor *> inputs_; | ||||
| std::vector<lite::Tensor *> outputs_; | std::vector<lite::Tensor *> outputs_; | ||||
| auto param = new ConstantOfShapeParameter(); | auto param = new ConstantOfShapeParameter(); | ||||
| param->value_ = 1; | |||||
| param->value_.f32_value_ = 1; | |||||
| param->data_type_ = kNumberTypeFloat32; | |||||
| float a[] = {1, 2, 3, 4}; | float a[] = {1, 2, 3, 4}; | ||||
| std::vector<int> a_shape = {4, 1, 1, 1}; | std::vector<int> a_shape = {4, 1, 1, 1}; | ||||
| // std::vector<int> c_shape = {2, 2, 2, 1}; | // std::vector<int> c_shape = {2, 2, 2, 1}; | ||||