Browse Source

constant of shape support int64

tags/v1.2.0-rc1
ling 5 years ago
parent
commit
e788f00976
8 changed files with 114 additions and 103 deletions
  1. +31
    -0
      mindspore/lite/nnacl/constant_of_shape.c
  2. +7
    -5
      mindspore/lite/nnacl/constant_of_shape.h
  3. +0
    -39
      mindspore/lite/nnacl/fp32/constant_of_shape_fp32.c
  4. +26
    -9
      mindspore/lite/src/ops/constant_of_shape.cc
  5. +12
    -3
      mindspore/lite/src/ops/populate/constant_of_shape_populate.cc
  6. +26
    -36
      mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.cc
  7. +9
    -9
      mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.h
  8. +3
    -2
      mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/constant_of_shape_fp32_test.cc

+ 31
- 0
mindspore/lite/nnacl/constant_of_shape.c View File

@@ -0,0 +1,31 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "nnacl/constant_of_shape.h"

int ConstantOfShapeInt32(int32_t *output, int start, int end, int32_t value) {
for (int i = start; i < end; i++) {
output[i] = value;
}
return NNACL_OK;
}

int ConstantOfShapeFp32(float *output, int start, int end, float value) {
for (int i = start; i < end; i++) {
output[i] = value;
}
return NNACL_OK;
}

mindspore/lite/nnacl/fp32/constant_of_shape_fp32.h → mindspore/lite/nnacl/constant_of_shape.h View File

@@ -24,17 +24,19 @@


typedef struct ConstantOfShapeParameter { typedef struct ConstantOfShapeParameter {
OpParameter op_parameter_; OpParameter op_parameter_;
float value_;
union value_ {
float f32_value_;
int32_t int32_value_;
} value_;
int data_type_; int data_type_;
int unit_;
int element_sz_;
int element_size_;
} ConstantOfShapeParameter; } ConstantOfShapeParameter;


#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
int ConstantOfShape(float *output, int tid, const ConstantOfShapeParameter *param);
int ConstantOfShapeInt(int32_t *output, int tid, const ConstantOfShapeParameter *param);
int ConstantOfShapeFp32(float *output, int start, int end, float value);
int ConstantOfShapeInt32(int32_t *output, int start, int end, int32_t value);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

+ 0
- 39
mindspore/lite/nnacl/fp32/constant_of_shape_fp32.c View File

@@ -1,39 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "nnacl/fp32/constant_of_shape_fp32.h"

int ConstantOfShape(float *output, int tid, const ConstantOfShapeParameter *param) {
int size = param->unit_;
float data = param->value_;
int ind_st = MSMIN(tid * size, param->element_sz_);
int ind_end = MSMIN(param->element_sz_, (tid + 1) * size);
for (int i = ind_st; i < ind_end; ++i) {
output[i] = data;
}
return NNACL_OK;
}

int ConstantOfShapeInt(int32_t *output, int tid, const ConstantOfShapeParameter *param) {
int size = param->unit_;
float data = param->value_;
int ind_st = MSMIN(tid * size, param->element_sz_);
int ind_end = MSMIN(param->element_sz_, (tid + 1) * size);
for (int i = ind_st; i < ind_end; ++i) {
output[i] = data;
}
return NNACL_OK;
}

+ 26
- 9
mindspore/lite/src/ops/constant_of_shape.cc View File

@@ -78,25 +78,42 @@ int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
MS_LOG(ERROR) << "outputs to ConstantOfShape operator should be 1, but " << outputs_.size() << " is given."; MS_LOG(ERROR) << "outputs to ConstantOfShape operator should be 1, but " << outputs_.size() << " is given.";
return RET_ERROR; return RET_ERROR;
} }

auto in_tensor = inputs_.front(); auto in_tensor = inputs_.front();
auto out_tensor = outputs_.front(); auto out_tensor = outputs_.front();
out_tensor->set_data_type(static_cast<TypeId>(GetDataType())); out_tensor->set_data_type(static_cast<TypeId>(GetDataType()));
out_tensor->set_format(in_tensor->format()); out_tensor->set_format(in_tensor->format());
if (!infer_flag()) {
return RET_INFER_INVALID;
}
auto in_data = reinterpret_cast<int *>(in_tensor->data_c());
if (in_data == nullptr) {
MS_LOG(INFO) << "Input data is nullptr. Input tensor has not been calculated out yet.";

if (!infer_flag() || in_tensor->data_c() == nullptr) {
return RET_INFER_INVALID; return RET_INFER_INVALID;
} }

int size = in_tensor->ElementsNum(); int size = in_tensor->ElementsNum();
std::vector<int> out_shape(size); std::vector<int> out_shape(size);
for (int i = 0; i < size; ++i) {
out_shape[i] = in_data[i];

switch (in_tensor->data_type()) {
case kNumberTypeInt32: {
int32_t *in_data = reinterpret_cast<int32_t *>(in_tensor->data_c());
for (int i = 0; i < size; ++i) {
out_shape[i] = in_data[i];
MS_ASSERT(out_shape[i] > 0);
}
break;
}
case kNumberTypeInt64: {
int64_t *in_data = reinterpret_cast<int64_t *>(in_tensor->data_c());
for (int i = 0; i < size; ++i) {
out_shape[i] = in_data[i];
MS_ASSERT(out_shape[i] > 0);
}
break;
}
default:
MS_LOG(INFO) << "Invalid input data type!";
return RET_INFER_INVALID;
} }
out_tensor->set_shape(out_shape);


out_tensor->set_shape(out_shape);
return RET_OK; return RET_OK;
} }
} // namespace mindspore::lite } // namespace mindspore::lite

+ 12
- 3
mindspore/lite/src/ops/populate/constant_of_shape_populate.cc View File

@@ -19,7 +19,7 @@
#include "src/tensor.h" #include "src/tensor.h"
#include "src/ops/primitive_c.h" #include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h" #include "src/ops/populate/populate_register.h"
#include "nnacl/fp32/constant_of_shape_fp32.h"
#include "nnacl/constant_of_shape.h"


namespace mindspore::lite { namespace mindspore::lite {
namespace { namespace {
@@ -34,13 +34,22 @@ OpParameter *PopulateConstantOfShapeParameter(const mindspore::lite::PrimitiveC
} }
memset(param, 0, sizeof(ConstantOfShapeParameter)); memset(param, 0, sizeof(ConstantOfShapeParameter));
param->op_parameter_.type_ = primitive->Type(); param->op_parameter_.type_ = primitive->Type();
param->data_type_ = attr->GetDataType();
auto value = attr->GetValue(); auto value = attr->GetValue();
if (value.empty() || value.size() > 1) { if (value.empty() || value.size() > 1) {
MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1."; MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1.";
} else { } else {
param->value_ = attr->GetValue().at(0);
switch (param->data_type_) {
case kNumberTypeFloat32:
param->value_.f32_value_ = attr->GetValue().at(0);
break;
case kNumberTypeInt32:
param->value_.int32_value_ = attr->GetValue().at(0);
break;
default:
MS_LOG(ERROR) << "The value of constant of shape is invalid";
}
} }
param->data_type_ = attr->GetDataType();
return reinterpret_cast<OpParameter *>(param); return reinterpret_cast<OpParameter *>(param);
} }
Registry ConstantOfShapeParameterRegistry(schema::PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter); Registry ConstantOfShapeParameterRegistry(schema::PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter);


mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape_fp32.cc → mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.cc View File

@@ -14,11 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */


#include "src/runtime/kernel/arm/fp32/constant_of_shape_fp32.h"
#include <vector>
#include "src/runtime/kernel/arm/base/constant_of_shape.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"


using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
@@ -28,30 +26,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_ConstantOfShape; using mindspore::schema::PrimitiveType_ConstantOfShape;


namespace mindspore::kernel { namespace mindspore::kernel {
int ConstantOfShapeCPUKernel::Init() { return RET_OK; }

int ConstantOfShapeCPUKernel::ReSize() { return RET_OK; }

int ConstantOfShapeCPUKernel::DoExecute(int task_id) {
int ret = RET_ERROR;
switch (param_->data_type_) {
case kNumberTypeFloat32:
ret = ConstantOfShape(reinterpret_cast<float *>(out_ptr_), task_id, param_);
break;
case kNumberTypeInt32:
ret = ConstantOfShapeInt(reinterpret_cast<int32_t *>(out_ptr_), task_id, param_);
break;
default:
MS_LOG(ERROR) << "Constant of shape does not support the output data type.";
return RET_ERROR;
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConstantOfShapeRun error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
}
return RET_OK;
}

int ConstantOfShapeRun(void *cdata, int task_id) { int ConstantOfShapeRun(void *cdata, int task_id) {
auto g_kernel = reinterpret_cast<ConstantOfShapeCPUKernel *>(cdata); auto g_kernel = reinterpret_cast<ConstantOfShapeCPUKernel *>(cdata);
auto ret = g_kernel->DoExecute(task_id); auto ret = g_kernel->DoExecute(task_id);
@@ -62,23 +36,38 @@ int ConstantOfShapeRun(void *cdata, int task_id) {
return RET_OK; return RET_OK;
} }


int ConstantOfShapeCPUKernel::Run() {
param_->element_sz_ = out_tensors_.front()->ElementsNum();
int thread_num = MSMIN(param_->op_parameter_.thread_num_, param_->element_sz_);
param_->unit_ = UP_DIV(param_->element_sz_, thread_num);
param_->op_parameter_.thread_num_ = thread_num;
int ConstantOfShapeCPUKernel::DoExecute(int task_id) {
int start = task_id * thread_stride_;
int current_stride = MSMIN(thread_stride_, param_->element_size_ - start);
if (current_stride < 0) {
return RET_OK;
}

switch (param_->data_type_) { switch (param_->data_type_) {
case kNumberTypeFloat32: case kNumberTypeFloat32:
out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
ConstantOfShapeFp32(reinterpret_cast<float *>(output_ptr_), start, start + current_stride,
param_->value_.f32_value_);
break; break;
case kNumberTypeInt32: case kNumberTypeInt32:
out_ptr_ = reinterpret_cast<int32_t *>(out_tensors_.front()->MutableData());
ConstantOfShapeInt32(reinterpret_cast<int32_t *>(output_ptr_), start, start + current_stride,
param_->value_.int32_value_);
break; break;
default: default:
MS_LOG(ERROR) << "Constant of shape does not support the output data type.";
MS_LOG(ERROR) << "Invalid datatype in ConstantOfShapeRun";
return RET_ERROR; return RET_ERROR;
} }
auto ret = ParallelLaunch(this->context_->thread_pool_, ConstantOfShapeRun, this, thread_num);
return RET_OK;
}

int ConstantOfShapeCPUKernel::Run() {
auto output = out_tensors_.front();
param_->data_type_ = output->data_type();
param_->element_size_ = output->ElementsNum();
output_ptr_ = output->data_c();
int thread_count = MSMIN(op_parameter_->thread_num_, param_->element_size_);
thread_stride_ = UP_DIV(param_->element_size_, thread_count);

auto ret = ParallelLaunch(this->context_->thread_pool_, ConstantOfShapeRun, this, thread_count);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ConstantOfShapeRun error error_code[" << ret << "]"; MS_LOG(ERROR) << "ConstantOfShapeRun error error_code[" << ret << "]";
return ret; return ret;
@@ -88,4 +77,5 @@ int ConstantOfShapeCPUKernel::Run() {


REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
} // namespace mindspore::kernel } // namespace mindspore::kernel

mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape_fp32.h → mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.h View File

@@ -13,15 +13,14 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_CONSTANT_OF_SHAPE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_CONSTANT_OF_SHAPE_H_


#include <vector> #include <vector>
#include "include/errorcode.h"
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
#include "include/context.h" #include "include/context.h"
#include "nnacl/fp32/constant_of_shape_fp32.h"

using mindspore::lite::InnerContext;
#include "nnacl/constant_of_shape.h"


namespace mindspore::kernel { namespace mindspore::kernel {
class ConstantOfShapeCPUKernel : public LiteKernel { class ConstantOfShapeCPUKernel : public LiteKernel {
@@ -34,15 +33,16 @@ class ConstantOfShapeCPUKernel : public LiteKernel {
} }
~ConstantOfShapeCPUKernel() override = default; ~ConstantOfShapeCPUKernel() override = default;


int Init() override;
int ReSize() override;
int Init() override { return lite::RET_OK; }
int ReSize() override { return lite::RET_OK; }
int Run() override; int Run() override;
int DoExecute(int task_id); int DoExecute(int task_id);


private: private:
ConstantOfShapeParameter *param_ = nullptr; ConstantOfShapeParameter *param_ = nullptr;
void *out_ptr_ = nullptr;
void *output_ptr_ = nullptr;
int thread_stride_;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel


#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_CONSTANT_OF_SHAPE_H_

+ 3
- 2
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/constant_of_shape_fp32_test.cc View File

@@ -15,7 +15,7 @@
*/ */
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "common/common_test.h" #include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape_fp32.h"
#include "mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/lite_kernel.h" #include "src/lite_kernel.h"


@@ -47,7 +47,8 @@ TEST_F(TestConstantOfShapeFp32, Simple) {
std::vector<lite::Tensor *> inputs_; std::vector<lite::Tensor *> inputs_;
std::vector<lite::Tensor *> outputs_; std::vector<lite::Tensor *> outputs_;
auto param = new ConstantOfShapeParameter(); auto param = new ConstantOfShapeParameter();
param->value_ = 1;
param->value_.f32_value_ = 1;
param->data_type_ = kNumberTypeFloat32;
float a[] = {1, 2, 3, 4}; float a[] = {1, 2, 3, 4};
std::vector<int> a_shape = {4, 1, 1, 1}; std::vector<int> a_shape = {4, 1, 1, 1};
// std::vector<int> c_shape = {2, 2, 2, 1}; // std::vector<int> c_shape = {2, 2, 2, 1};


Loading…
Cancel
Save