
!22343 [MS][LITE] add scatter_nd_update ops

Merge pull request !22343 from mengyuanli/add_scatter_update_ops
Tag: v1.5.0-rc1
Committed by i-robot (Gitee), 4 years ago
Commit: 29d5d61d8a
8 changed files with 222 additions and 15 deletions
  1. +0 -7    mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_update_infer.c
  2. +2 -2    mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to_fp32.cc
  3. +2 -2    mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_fp32.cc
  4. +165 -0  mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.cc
  5. +49 -0   mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.h
  6. +1 -1    mindspore/lite/src/runtime/kernel/arm/fp32/where_fp32.cc
  7. +2 -2    mindspore/lite/tools/converter/parser/onnx/onnx_scatter_nd_parser.cc
  8. +1 -1    mindspore/lite/tools/optimizer/format/delete_redundant_transpose.cc

+0 -7  mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/scatter_nd_update_infer.c

@@ -25,18 +25,11 @@ int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size,
  }

  const TensorC *input_x = inputs[0];
  const TensorC *indices = inputs[1];
  const TensorC *update = inputs[2];
  TensorC *output = outputs[0];

  SetDataTypeFormat(output, input_x);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  if (indices->shape_size_ != update->shape_size_) {
    return NNACL_ERR;
  }

  SetShapeArray(output, input_x->shape_, input_x->shape_size_);
  return NNACL_OK;
}


+2 -2  mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to_fp32.cc

@@ -53,8 +53,8 @@ int BroadcastToCPUKernel::ReSize() {
}

int BroadcastToCPUKernel::Init() {
-  CHECK_LESS_RETURN(in_tensors_.size(), 1);
-  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_LESS_RETURN(in_tensors_.size(), 1)
+  CHECK_LESS_RETURN(out_tensors_.size(), 1)
  shape_info_ = reinterpret_cast<BroadcastShapeInfo *>(malloc(sizeof(BroadcastShapeInfo)));
  if (shape_info_ == nullptr) {
    MS_LOG(ERROR) << "Malloc BroadcastShapeInfo failed!";


+2 -2  mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_fp32.cc

@@ -34,8 +34,8 @@ constexpr int kScatterIndicesIndex = 1;
constexpr int kScatterUpdateIndex = 2;
}  // namespace
int ScatterNDCPUKernel::Init() {
-  CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
-  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D)
+  CHECK_LESS_RETURN(out_tensors_.size(), 1)
  if (!InferShapeDone()) {
    return RET_OK;
  }


+165 -0  mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.cc

@@ -0,0 +1,165 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.h"
#include <cstring>
#include "src/runtime/kernel/arm/fp32/scatter_nd_fp32.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_ScatterNdUpdate;

namespace mindspore::kernel {
namespace {
constexpr int kScatterUpdateInputIndex = 0;
constexpr int kScatterIndicesIndex = 1;
constexpr int kScatterUpdateIndex = 2;
constexpr size_t kScatterIndicesDims = 2;
} // namespace
int ScatterNdUpdateCPUKernel::Init() {
  CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D)
  CHECK_LESS_RETURN(out_tensors_.size(), 1)
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int ScatterNdUpdateCPUKernel::ReSize() {
  auto input = in_tensors_.at(kScatterUpdateInputIndex);
  auto indices = in_tensors_.at(kScatterIndicesIndex);
  auto update = in_tensors_.at(kScatterUpdateIndex);
  auto output = out_tensors_.front();

  update_ptr_ = reinterpret_cast<float *>(update->MutableData());
  MS_ASSERT(update_ptr_ != nullptr);
  output_ptr_ = reinterpret_cast<float *>(output->MutableData());
  MS_ASSERT(output_ptr_ != nullptr);

  // check indices shape
  int input_rank = static_cast<int>(input->shape().size());
  int indice_unit_rank = indices->shape().back();
  if (indice_unit_rank > input_rank) {
    MS_LOG(ERROR) << "Value of last dimension of indices is greater than input rank.";
    return RET_ERROR;
  }

  if (indices->shape().size() < kScatterIndicesDims) {
    MS_LOG(ERROR) << "Indices dimension smaller than 2.";
    return RET_ERROR;
  }

  // check consistency of the shape indices and shape
  int update_rank = static_cast<int>(update->shape().size());
  auto indices_shape = indices->shape();
  auto update_shape = update->shape();
  unit_size_ = 1;
  for (int i = indices_shape.size() - 1; i < update_rank; i++) {
    unit_size_ *= update_shape.at(i);
  }

  // calculate offsets
  int out_stride = 1;
  out_strides_.push_back(1);
  for (int i = indice_unit_rank - 2; i >= 0; i--) {
    out_stride *= input->shape()[i + 1];
    out_strides_.push_back(out_stride);
  }
  std::reverse(out_strides_.begin(), out_strides_.end());

  num_unit_ = 1;
  num_unit_ *= update_shape.at(indices_shape.size() - 2);
  for (int i = indices_shape.size() - 3; i >= 0; i--) {
    num_unit_ *= update_shape.at(i);
  }

  int *indices_ptr = reinterpret_cast<int *>(indices->MutableData());
  MS_ASSERT(indices_ptr != nullptr);
  output_unit_offsets_.clear();
  for (int i = 0; i < num_unit_; i++) {
    int tmp_stride = 0;
    for (int j = 0; j < indice_unit_rank; j++) {
      tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides_.at(j) * unit_size_;
    }
    output_unit_offsets_.push_back(tmp_stride);
  }

  thread_n_num_ = MSMIN(op_parameter_->thread_num_, num_unit_);
  if (thread_n_num_ == 0) {
    return RET_ERROR;
  }
  thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
  return RET_OK;
}

int ScatterNdUpdateCPUKernel::ScatterNdUpdate(int task_id) {
  int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
  if (num_unit_thread <= 0) {
    return RET_OK;
  }
  int offset = task_id * thread_n_stride_;
  auto ret = DoScatterND(output_ptr_, update_ptr_ + offset * unit_size_, output_unit_offsets_.data() + offset,
                         unit_size_, num_unit_thread);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ScatterNdUpdate error task_id[" << task_id << "] error_code[" << ret << "]";
    return RET_ERROR;
  }
  in_tensors_.at(kScatterUpdateInputIndex)->IncRefCount();
  return RET_OK;
}

int ScatterNdUpdateRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
  auto g_kernel = reinterpret_cast<ScatterNdUpdateCPUKernel *>(cdata);
  MS_ASSERT(g_kernel != nullptr);
  auto ret = g_kernel->ScatterNdUpdate(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ScatterNdUpdateRun error task_id[" << task_id << "] error_code[" << ret << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int ScatterNdUpdateCPUKernel::Run() {
  auto in_tensor = in_tensors().front();
  auto out_tensor = out_tensors().front();
  if (in_tensor->allocator() == nullptr || in_tensor->allocator() != out_tensor->allocator() ||
      op_parameter_->is_train_session_) {
    memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size());
  } else {
    out_tensor->FreeData();
    out_tensor->ResetRefCount();
    in_tensor->allocator()->IncRefCount(in_tensor->data(), out_tensor->ref_count());
    out_tensor->set_data(in_tensor->data());
    out_tensor->set_own_data(in_tensor->own_data());
    output_ptr_ = reinterpret_cast<float *>(out_tensor->data());
  }

  auto ret = ParallelLaunch(this->ms_context_, ScatterNdUpdateRun, this, thread_n_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ScatterNdUpdate error error_code[" << ret << "]";
    return RET_ERROR;
  }

  return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ScatterNdUpdate, LiteKernelCreator<ScatterNdUpdateCPUKernel>)
} // namespace mindspore::kernel
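
Editor's note (not part of this commit): for readers unfamiliar with the offset math in ReSize above, the following standalone, single-threaded sketch reproduces the same idea outside the kernel framework. The last dimension of indices addresses a slice of the input, a stride table turns each index tuple into a flat element offset, and the update rows overwrite a copy of the input. DoScatterND, the lite::Tensor plumbing and the threading are replaced here by plain vectors; the function name scatter_nd_update_ref and the shapes in main are illustrative only.

#include <algorithm>
#include <cstdio>
#include <vector>

// Minimal reference of the ScatterNdUpdate logic: output starts as a copy of
// input, then each indexed slice is overwritten by the matching update row.
// Mirrors the stride/offset computation in ScatterNdUpdateCPUKernel::ReSize.
std::vector<float> scatter_nd_update_ref(const std::vector<float> &input, const std::vector<int> &input_shape,
                                         const std::vector<int> &indices, const std::vector<int> &indices_shape,
                                         const std::vector<float> &updates) {
  int indice_unit_rank = indices_shape.back();  // leading input dims addressed by one index tuple
  int unit_size = 1;                            // elements per updated slice
  for (size_t i = indice_unit_rank; i < input_shape.size(); ++i) {
    unit_size *= input_shape[i];
  }
  // out_strides[j]: stride (in slices) of input dimension j; innermost addressed dim has stride 1
  std::vector<int> out_strides(indice_unit_rank, 1);
  for (int j = indice_unit_rank - 2; j >= 0; --j) {
    out_strides[j] = out_strides[j + 1] * input_shape[j + 1];
  }
  int num_unit = 1;  // number of index tuples == number of updated slices
  for (size_t i = 0; i + 1 < indices_shape.size(); ++i) {
    num_unit *= indices_shape[i];
  }
  std::vector<float> output = input;  // copy-then-update semantics
  for (int i = 0; i < num_unit; ++i) {
    int offset = 0;
    for (int j = 0; j < indice_unit_rank; ++j) {
      offset += indices[i * indice_unit_rank + j] * out_strides[j] * unit_size;
    }
    std::copy_n(updates.begin() + i * unit_size, unit_size, output.begin() + offset);
  }
  return output;
}

int main() {
  // input shape [4, 2]; replace rows 0 and 2 with the update rows
  std::vector<float> input = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<int> indices = {0, 2};  // indices shape [2, 1]
  std::vector<float> updates = {9, 9, 7, 7};
  auto out = scatter_nd_update_ref(input, {4, 2}, indices, {2, 1}, updates);
  for (float v : out) printf("%g ", v);  // expected: 9 9 3 4 7 7 7 8
  printf("\n");
  return 0;
}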

+49 -0  mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.h

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCATTER_ND_UPDATE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCATTER_ND_UPDATE_H_

#include <vector>
#include "src/inner_kernel.h"

namespace mindspore::kernel {

class ScatterNdUpdateCPUKernel : public InnerKernel {
 public:
  explicit ScatterNdUpdateCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                    const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
      : InnerKernel(parameter, inputs, outputs, ctx) {}
  ~ScatterNdUpdateCPUKernel() override = default;

  int Init() override;
  int ReSize() override;
  int Run() override;
  int ScatterNdUpdate(int task_id);

 private:
  int thread_n_num_ = 1;
  int thread_n_stride_ = 1;
  int num_unit_ = 1;
  int unit_size_ = 1;
  float *output_ptr_ = nullptr;
  float *update_ptr_ = nullptr;
  std::vector<int> out_strides_;
  std::vector<int> output_unit_offsets_;
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCATTER_ND_UPDATE_H_
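
Editor's note (not part of this commit): the thread_n_num_ and thread_n_stride_ members declared above drive the ParallelLaunch split in the .cc file. A rough sketch of that partitioning, assuming MSMIN and UP_DIV behave like min and ceiling division (the real macros live in nnacl):

#include <algorithm>
#include <cstdio>

// Sketch of how the kernel splits num_unit slices across threads:
//   thread_n_num    = min(thread_num, num_unit)
//   thread_n_stride = ceil(num_unit / thread_n_num)
// task_id then handles at most thread_n_stride slices starting at task_id * thread_n_stride.
void plan_tasks(int thread_num, int num_unit) {
  int thread_n_num = std::min(thread_num, num_unit);
  if (thread_n_num == 0) return;  // nothing to schedule
  int thread_n_stride = (num_unit + thread_n_num - 1) / thread_n_num;  // UP_DIV
  for (int task_id = 0; task_id < thread_n_num; ++task_id) {
    int count = std::min(thread_n_stride, num_unit - task_id * thread_n_stride);
    if (count <= 0) continue;  // trailing tasks may get no work
    printf("task %d: units [%d, %d)\n", task_id, task_id * thread_n_stride, task_id * thread_n_stride + count);
  }
}

int main() {
  plan_tasks(/*thread_num=*/4, /*num_unit=*/10);  // splits as 3 + 3 + 3 + 1 slices
  return 0;
}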

+1 -1  mindspore/lite/src/runtime/kernel/arm/fp32/where_fp32.cc

@@ -142,7 +142,7 @@ int WhereCPUKernel::RunWithTripleInputs() {
    MS_LOG(ERROR) << "Error, inputs' length are zero !!!";
    return RET_ERROR;
  }
-  auto ret = ParallelLaunch(this->ms_context_, WhereRun, this, where_param_->thread_num_);
+  auto ret = ParallelLaunch(this->ms_context_, WhereRun, this, where_param_->op_parameter_.thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "WhereDwRun error: error_code[" << ret << "]";
    return RET_ERROR;


+2 -2  mindspore/lite/tools/converter/parser/onnx/onnx_scatter_nd_parser.cc

@@ -16,12 +16,12 @@

#include "tools/converter/parser/onnx/onnx_scatter_nd_parser.h"
#include <memory>
-#include "ops/scatter_nd.h"
+#include "ops/scatter_nd_update.h"

namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxScatterNdParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
-  auto prim = std::make_unique<ops::ScatterNd>();
+  auto prim = std::make_unique<ops::ScatterNdUpdate>();
  return prim.release();
}
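
Editor's note (not part of this commit): the parser change above maps ONNX ScatterND to the new ScatterNdUpdate primitive, which matches the ONNX operator's copy-then-update semantics: the output is a copy of data with the positions named by indices overwritten by updates. A minimal sketch of that behaviour for the rank-1 case, using the example values from the ONNX operator documentation:

#include <cstdio>
#include <vector>

int main() {
  // ONNX ScatterND, rank-1 data with indices of shape [4, 1]:
  // output starts as a copy of data, then updates[i] is written at data[indices[i]].
  std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<int> indices = {4, 3, 1, 7};
  std::vector<float> updates = {9, 10, 11, 12};

  std::vector<float> output = data;  // copy, not in-place accumulation
  for (size_t i = 0; i < indices.size(); ++i) {
    output[indices[i]] = updates[i];
  }
  for (float v : output) printf("%g ", v);  // expected: 1 11 3 10 9 6 7 12
  printf("\n");
  return 0;
}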



+1 -1  mindspore/lite/tools/optimizer/format/delete_redundant_transpose.cc

@@ -69,7 +69,7 @@ STATUS DeleteRedundantTranspose::DeleteNot4DTranspose(const FuncGraphPtr &func_g
    MS_LOG(ERROR) << "fetch transpose perm failed.";
    return lite::RET_ERROR;
  }
-  if (!shape.empty() && shape.size() != perm.size()) {
+  if (!shape.empty() && shape.size() != perm.size() && !(shape.size() == 1 && shape[0] == -1)) {
    MS_LOG(DEBUG) << "transpose node need to be deleted.";
    if (UpdateNodeFormat(func_graph, cnode) != lite::RET_OK) {
      MS_LOG(ERROR) << "update cnode format failed.";

