
!15712 [MS][LITE] add some functions for refactoring control flow

From: @mengyuanli
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
pull/15712/MERGE
mindspore-ci-bot (Gitee) · 4 years ago
commit 2823b43632
20 changed files with 316 additions and 30 deletions

  1. +29 -0   mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/partial_fusion_parameter.h
  2. +7  -7   mindspore/core/mindrt/include/actor/op_actor.h
  3. +16 -0   mindspore/lite/src/common/prim_util.cc
  4. +2  -0   mindspore/lite/src/common/prim_util.h
  5. +1  -1   mindspore/lite/src/common/tensor_util.cc
  6. +17 -15  mindspore/lite/src/lite_kernel.cc
  7. +35 -0   mindspore/lite/src/lite_kernel_util.cc
  8. +7  -0   mindspore/lite/src/lite_kernel_util.h
  9. +1  -1   mindspore/lite/src/lite_mindrt.cc
 10. +34 -0   mindspore/lite/src/ops/populate/call_populate.cc
 11. +1  -4   mindspore/lite/src/ops/populate/partial_populate.cc
 12. +38 -0   mindspore/lite/src/runtime/kernel/arm/base/call.cc
 13. +38 -0   mindspore/lite/src/runtime/kernel/arm/base/call.h
 14. +37 -0   mindspore/lite/src/runtime/kernel/arm/base/partial_fusion.cc
 15. +38 -0   mindspore/lite/src/runtime/kernel/arm/base/partial_fusion.h
 16. +3  -2   mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
 17. +6  -0   mindspore/lite/src/sub_graph_kernel.cc
 18. +2  -0   mindspore/lite/src/sub_graph_kernel.h
 19. +2  -0   mindspore/lite/src/tensor.cc
 20. +2  -0   mindspore/lite/src/tensor.h

+29 -0   mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/partial_fusion_parameter.h

@@ -0,0 +1,29 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_NNACL_PARTIAL_FUSION_H_
#define MINDSPORE_NNACL_PARTIAL_FUSION_H_

#include "nnacl/op_base.h"
#include "nnacl/common_func.h"
#include "nnacl/nnacl_utils.h"

typedef struct PartialParameter {
OpParameter op_parameter_;
int sub_graph_index_;
} PartialParameter;

#endif  // MINDSPORE_NNACL_PARTIAL_FUSION_H_
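The struct follows the usual nnacl parameter layout: OpParameter is the first member, so the pointer can travel through the runtime as a plain OpParameter * and be cast back wherever the subgraph index is needed. A minimal, self-contained sketch of that pattern (the helper names below are illustrative, not part of this commit, and OpParameter is simplified):

// Sketch of the nnacl "parameter struct" pattern; illustrative only.
#include <cstdlib>
#include <cstring>

struct OpParameter { int type_; };  // stand-in for nnacl/op_base.h
struct PartialParameter { OpParameter op_parameter_; int sub_graph_index_; };

OpParameter *MakePartialParameter(int sub_graph_index) {  // illustrative helper
  auto *param = static_cast<PartialParameter *>(malloc(sizeof(PartialParameter)));
  if (param == nullptr) return nullptr;
  memset(param, 0, sizeof(PartialParameter));
  param->sub_graph_index_ = sub_graph_index;
  // Valid only because op_parameter_ is the first member of PartialParameter.
  return reinterpret_cast<OpParameter *>(param);
}

int GetSubGraphIndex(const OpParameter *op) {  // illustrative helper
  return reinterpret_cast<const PartialParameter *>(op)->sub_graph_index_;
}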

+7 -7   mindspore/core/mindrt/include/actor/op_actor.h

@@ -54,7 +54,7 @@ using OpDataPtr = std::shared_ptr<OpData<T>>;
template <typename T>
struct OpContext {
uuids::uuid *sequential_num_;
- std::vector<OpDataPtr<T>> *outputData_;
+ std::vector<OpDataPtr<T>> *output_data_;
std::vector<Promise<int>> *results_;
const void *kernel_call_back_before_;
const void *kernel_call_back_after_;
@@ -97,14 +97,14 @@ class OpActor : public ActorBase {
};

template <typename T>
- Future<std::list<int>> MindrtAsyncRun(const std::vector<OpDataPtr<T>> &inputData, OpContext<T> *context) {
+ Future<std::list<int>> MindrtAsyncRun(const std::vector<OpDataPtr<T>> &input_data, OpContext<T> *context) {
std::list<Future<int>> futures;
for (auto promise : *(context->results_)) {
futures.push_back(promise.GetFuture());
}
Future<std::list<int>> collect = mindspore::Collect<int>(futures);

- for (auto data : inputData) {
+ for (auto data : input_data) {
Async(data->op_id_, &mindspore::OpActor<T>::RunOpData, data, context);
}

@@ -112,18 +112,18 @@ Future<std::list<int>> MindrtAsyncRun(const std::vector<OpDataPtr<T>> &inputData
}

template <typename T>
- int MindrtRun(const std::vector<OpDataPtr<T>> &inputData, std::vector<OpDataPtr<T>> *outputData,
+ int MindrtRun(const std::vector<OpDataPtr<T>> &input_data, std::vector<OpDataPtr<T>> *output_data,
const void *kernel_call_back_before, const void *kernel_call_back_after) {
OpContext<T> context;
- std::vector<Promise<int>> promises(outputData->size());
+ std::vector<Promise<int>> promises(output_data->size());
uuids::uuid uid;
context.sequential_num_ = &uid;
context.results_ = &promises;
- context.outputData_ = outputData;
+ context.output_data_ = output_data;
context.kernel_call_back_before_ = kernel_call_back_before;
context.kernel_call_back_after_ = kernel_call_back_after;

- auto collect = MindrtAsyncRun<T>(inputData, &context);
+ auto collect = MindrtAsyncRun<T>(input_data, &context);
collect.Wait();
if (!collect.IsOK()) {
return -1;
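With the fields renamed to snake_case, a caller only needs to pass the input and output OpData vectors plus the two optional callbacks. A hedged sketch of such a call site (RunGraphOnce is an illustrative name; placing the helpers in namespace mindspore is an assumption suggested by the qualified mindspore::OpActor calls above):

// Illustrative caller of the renamed MindrtRun; not part of this commit.
#include <vector>
#include "actor/op_actor.h"

template <typename T>
int RunGraphOnce(const std::vector<mindspore::OpDataPtr<T>> &input_data,
                 std::vector<mindspore::OpDataPtr<T>> *output_data) {
  // Pass nullptr when no before/after kernel callbacks are needed.
  return mindspore::MindrtRun<T>(input_data, output_data, nullptr, nullptr);
}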


+16 -0   mindspore/lite/src/common/prim_util.cc

@@ -66,6 +66,22 @@ bool IsPartialNode(const void *primitive) {
return false;
}

bool IsCallNode(const void *primitive) {
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_Call;
}
return false;
}

bool IsSwitchNode(const void *primitive) {
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_Switch;
}
return false;
}

int GetPartialGraphIndex(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int index = -1;
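Together with IsPartialNode, the new IsCallNode / IsSwitchNode predicates let callers classify a primitive without touching the schema directly. A small sketch of how they might be combined (ControlFlowKind and ClassifyPrimitive are illustrative names; the namespace is assumed to be mindspore::lite like the rest of src/common):

// Illustrative classification helper built on the predicates added above.
#include "src/common/prim_util.h"

enum class ControlFlowKind { kPartial, kSwitch, kCall, kOther };

ControlFlowKind ClassifyPrimitive(const void *primitive) {
  if (mindspore::lite::IsPartialNode(primitive)) return ControlFlowKind::kPartial;
  if (mindspore::lite::IsSwitchNode(primitive)) return ControlFlowKind::kSwitch;
  if (mindspore::lite::IsCallNode(primitive)) return ControlFlowKind::kCall;
  return ControlFlowKind::kOther;
}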


+2 -0   mindspore/lite/src/common/prim_util.h

@@ -24,6 +24,8 @@ const char *PrimitiveTypeName(int type);
const char *PrimitiveCurVersionTypeName(int type);
int GenPrimVersionKey(int primitive_type, int schema_version);
bool IsPartialNode(const void *primitive);
bool IsCallNode(const void *node);
bool IsSwitchNode(const void *node);
int GetPartialGraphIndex(const void *primitive);
bool IsWhileNode(const void *primitive);
int GetWhileBodySubgraphIndex(const void *primitive);


+1 -1   mindspore/lite/src/common/tensor_util.cc

@@ -220,7 +220,7 @@ int CheckTensorsInvalid(const std::vector<Tensor *> &tensors) {
return RET_ERROR;
}
if (tensor->data_type() != kObjectTypeTensorType && tensor->data_c() == nullptr) {
MS_LOG(ERROR) << "Graph input tensor is nullptr " << tensors;
MS_LOG(ERROR) << "Graph input tensor data is nullptr " << tensor->tensor_name();
return RET_ERROR;
}
auto shape = tensor->shape();


+17 -15   mindspore/lite/src/lite_kernel.cc

@@ -188,23 +188,25 @@ void LiteKernel::FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope
// clean io kernels
this->in_kernels_.clear();
this->out_kernels_.clear();
// find io kernels
for (auto *scope_kernel : scope_kernels) {
if (scope_kernel == this) {
continue;
}
for (auto *tensor : this->in_tensors_) {
if (lite::IsContain(scope_kernel->out_tensors(), tensor)) {
if (!lite::IsContain(this->in_kernels(), scope_kernel)) {
this->AddInKernel(scope_kernel);
}
// find io kernels; the lookup time here still needs optimization
for (auto *tensor : this->in_tensors_) {
for (auto *scope_kernel : scope_kernels) {
if (scope_kernel == this) {
continue;
}
if (lite::IsContain(scope_kernel->out_tensors(), tensor) && !lite::IsContain(this->in_kernels(), scope_kernel)) {
this->AddInKernel(scope_kernel);
}
}
for (auto *tensor : this->out_tensors_) {
if (lite::IsContain(scope_kernel->in_tensors(), tensor)) {
if (!lite::IsContain(this->out_kernels(), scope_kernel)) {
this->AddOutKernel(scope_kernel);
}
}

for (auto *tensor : this->out_tensors_) {
for (auto *scope_kernel : scope_kernels) {
if (scope_kernel == this) {
continue;
}
if (lite::IsContain(scope_kernel->in_tensors(), tensor) && !lite::IsContain(this->out_kernels(), scope_kernel)) {
this->AddOutKernel(scope_kernel);
}
}
}


+35 -0   mindspore/lite/src/lite_kernel_util.cc

@@ -17,6 +17,7 @@
#include "src/lite_kernel_util.h"
#include <queue>
#include <set>
#include "src/sub_graph_kernel.h"

namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
@@ -187,4 +188,38 @@ void LiteKernelUtil::InitTensorInitRefCount(const std::vector<kernel::LiteKernel

int LiteKernelUtil::SetInput(const LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs) { return -1; }

bool LiteKernelUtil::IsSwitchCall(kernel::LiteKernel *kernel) {
auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
if (subgraph_kernel == nullptr) {
return false;
}
for (auto &node : subgraph_kernel->nodes()) {
if (node->Type() == schema::PrimitiveType_Switch &&
InputsContainsSpecificNode(node, schema::PrimitiveType_PartialFusion) && node->out_kernels().size() == 1 &&
node->out_kernels().front()->Type() == schema::PrimitiveType_Call) {
return true;
}
}

return false;
}

kernel::LiteKernel *LiteKernelUtil::GetInputsSpecificNode(const kernel::LiteKernel *kernel,
const schema::PrimitiveType &primitive_type) {
for (auto input : kernel->in_kernels()) {
if (input->Type() == primitive_type) {
return input;
}
}
return nullptr;
}

bool LiteKernelUtil::InputsContainsSpecificNode(const kernel::LiteKernel *kernel,
const schema::PrimitiveType &primitive_type) {
if (GetInputsSpecificNode(kernel, primitive_type)) {
return true;
}
return false;
}

} // namespace mindspore::kernel
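IsSwitchCall recognizes the call(switch(...)) pattern: a Switch node fed by at least one PartialFusion whose single output kernel is a Call node. A hedged usage sketch of the new helper (CollectSwitchCallSubgraphs is an illustrative name, not an API in this commit):

// Illustrative scan over scheduled subgraph kernels using the new helper.
#include <vector>
#include "src/lite_kernel_util.h"

void CollectSwitchCallSubgraphs(const std::vector<mindspore::kernel::LiteKernel *> &subgraphs,
                                std::vector<mindspore::kernel::LiteKernel *> *switch_calls) {
  for (auto *subgraph : subgraphs) {
    if (mindspore::kernel::LiteKernelUtil::IsSwitchCall(subgraph)) {
      switch_calls->push_back(subgraph);
    }
  }
}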

+7 -0   mindspore/lite/src/lite_kernel_util.h

@@ -35,6 +35,13 @@ class LiteKernelUtil {
static void InitTensorInitRefCount(const std::vector<kernel::LiteKernel *> &kernels);

static int SetInput(const LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs);

static bool IsSwitchCall(kernel::LiteKernel *kernel);

static kernel::LiteKernel *GetInputsSpecificNode(const kernel::LiteKernel *kernel,
const schema::PrimitiveType &primitive_type);

static bool InputsContainsSpecificNode(const kernel::LiteKernel *kernel, const schema::PrimitiveType &primitive_type);
};

} // namespace mindspore::kernel


+1 -1   mindspore/lite/src/lite_mindrt.cc

@@ -48,7 +48,7 @@ int LiteOpActor::CompileArrow() {

void LiteOpActor::AsyncOutput(OpContext<Tensor> *context) {
for (auto op_arrow : output_op_arrows_) {
- auto data = context->outputData_->at(op_arrow->from_output_index_);
+ auto data = context->output_data_->at(op_arrow->from_output_index_);
Async(op_arrow->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data, context);
}
return;


+34 -0   mindspore/lite/src/ops/populate/call_populate.cc

@@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/populate/populate_register.h"
using mindspore::schema::PrimitiveType_Call;

namespace mindspore {
namespace lite {
OpParameter *PopulateCallParameter(const void *prim) {
OpParameter *call_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (call_parameter == nullptr) {
MS_LOG(ERROR) << "malloc CallParameter failed.";
return nullptr;
}
memset(call_parameter, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
call_parameter->type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(call_parameter);
}
REG_POPULATE(PrimitiveType_Call, PopulateCallParameter, SCHEMA_CUR)
} // namespace lite
} // namespace mindspore

+1 -4   mindspore/lite/src/ops/populate/partial_populate.cc

@@ -14,14 +14,11 @@
* limitations under the License.
*/
#include "src/ops/populate/populate_register.h"
#include "nnacl/partial_fusion_parameter.h"
using mindspore::schema::PrimitiveType_PartialFusion;

namespace mindspore {
namespace lite {
- typedef struct PartialParameter {
-   OpParameter op_parameter_;
-   int sub_graph_index_;
- } PartialParameter;

OpParameter *PopulatePartialParameter(const void *prim) {
PartialParameter *partial_parameter = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter)));


+38 -0   mindspore/lite/src/runtime/kernel/arm/base/call.cc

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/base/call.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/tensorlist.h"
#include "src/common/utils.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Call;

// This file becomes unnecessary once actor creation is moved before scheduling.
namespace mindspore::kernel {
int CallCPUKernel::Init() { return RET_OK; }
int CallCPUKernel::ReSize() { return RET_OK; }
int CallCPUKernel::Run() { return RET_OK; }

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Call, LiteKernelCreator<CallCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Call, LiteKernelCreator<CallCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Call, LiteKernelCreator<CallCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Call, LiteKernelCreator<CallCPUKernel>)
} // namespace mindspore::kernel

+38 -0   mindspore/lite/src/runtime/kernel/arm/base/call.h

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CALL_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CALL_H_

#include <vector>
#include "src/runtime/kernel/arm/base/carry_data.h"
#include "src/tensor.h"
#include "src/tensorlist.h"

// This file becomes unnecessary once actor creation is moved before scheduling.
namespace mindspore::kernel {
class CallCPUKernel : public LiteKernel {
public:
CallCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: LiteKernel(parameter, inputs, outputs, ctx) {}
~CallCPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CALL_H_

+37 -0   mindspore/lite/src/runtime/kernel/arm/base/partial_fusion.cc

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/base/partial_fusion.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/tensorlist.h"
#include "src/common/utils.h"

// This file will be removed once actor creation is moved before scheduling.
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_PartialFusion;

namespace mindspore::kernel {
int PartialFusionKernel::Init() { return RET_OK; }
int PartialFusionKernel::ReSize() { return RET_OK; }
int PartialFusionKernel::Run() { return RET_OK; }
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_PartialFusion, LiteKernelCreator<PartialFusionKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_PartialFusion, LiteKernelCreator<PartialFusionKernel>)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_PartialFusion, LiteKernelCreator<PartialFusionKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_PartialFusion, LiteKernelCreator<PartialFusionKernel>)
} // namespace mindspore::kernel

+38 -0   mindspore/lite/src/runtime/kernel/arm/base/partial_fusion.h

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_PARTIAL_FUSION_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_PARTIAL_FUSION_H_

#include <vector>
#include "src/runtime/kernel/arm/base/carry_data.h"
#include "src/tensor.h"
#include "src/tensorlist.h"

// This file will be removed once actor creation is moved before scheduling.
namespace mindspore::kernel {
class PartialFusionKernel : public LiteKernel {
public:
PartialFusionKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: LiteKernel(parameter, inputs, outputs, ctx) {}
~PartialFusionKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_PARTIAL_FUSION_H_

+3 -2   mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc

@@ -72,7 +72,8 @@ int ArithmeticCPUKernel::CheckDataType() {
auto in0_dataType = in_tensors_.at(0)->data_type();
auto in1_dataType = in_tensors_.at(1)->data_type();
if (in0_dataType != in1_dataType) {
MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same.";
MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same. input 0 dataType: "
<< in0_dataType << " input 1 dataType: " << in1_dataType;
return RET_ERROR;
}
return RET_OK;
@@ -408,7 +409,7 @@ int ArithmeticsRun(void *cdata, int task_id) {

int ArithmeticCPUKernel::Run() {
if (CheckDataType() != RET_OK) {
MS_LOG(ERROR) << "ArithmeticCPUKernel check dataType failed.";
MS_LOG(ERROR) << "ArithmeticCPUKernel check dataType failed, kernel name: " << this->name();
return RET_ERROR;
}
if (!input0_broadcast_) {


+6 -0   mindspore/lite/src/sub_graph_kernel.cc

@@ -23,6 +23,7 @@
#include "src/common/version_manager.h"
#include "src/runtime/infer_manager.h"
#include "src/common/tensor_util.h"
#include "src/common/utils.h"

namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
@@ -141,6 +142,11 @@ void SubGraphKernel::InitOutTensorInitRefCount() {
node->InitOutTensorInitRefCount();
}
}
void SubGraphKernel::DropNode(LiteKernel *node) {
lite::VectorErase(&nodes_, node);
lite::VectorErase(&in_nodes_, node);
lite::VectorErase(&out_nodes_, node);
}

int CpuSubGraph::Prepare() {
auto ret = SubGraphKernel::Prepare();
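DropNode detaches a kernel from the subgraph's nodes_, in_nodes_ and out_nodes_ lists via lite::VectorErase (declared in src/common/utils.h, hence the new include). A minimal stand-in for that helper, shown only to illustrate the intended effect; the real implementation may differ in details such as its return value:

// Stand-in for lite::VectorErase as used by DropNode; illustrative only.
#include <algorithm>
#include <vector>

template <typename T>
void VectorEraseSketch(std::vector<T> *vec, const T &element) {
  vec->erase(std::remove(vec->begin(), vec->end(), element), vec->end());
}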


+2 -0   mindspore/lite/src/sub_graph_kernel.h

@@ -112,6 +112,8 @@ class SubGraphKernel : public LiteKernel {

std::vector<LiteKernel *> nodes() { return this->nodes_; }

void DropNode(LiteKernel *node);

protected:
std::vector<LiteKernel *> nodes_{};
// entry nodes in nodes


+2 -0   mindspore/lite/src/tensor.cc

@@ -348,6 +348,8 @@ void *Tensor::MutableData() {
return this->data_;
}

void Tensor::IncRefCount() { ++ref_count_; }

void Tensor::DecRefCount() {
if (this->IsConst() || this->IsGraphInput()) {
return;


+2 -0   mindspore/lite/src/tensor.h

@@ -145,6 +145,8 @@ class Tensor : public mindspore::tensor::MSTensor {

void ResetRefCount() { this->ref_count_ = this->init_ref_count_; }

void IncRefCount();

void DecRefCount();

std::string ToString() const;

