From: @mengyuanli Reviewed-by: @zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tongpull/15712/MERGE
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_NNACL_PARTIAL_FUSION_H_ | |||||
| #define MINDSPORE_NNACL_PARTIAL_FUSION_H_ | |||||
| #include "nnacl/op_base.h" | |||||
| #include "nnacl/common_func.h" | |||||
| #include "nnacl/nnacl_utils.h" | |||||
| typedef struct PartialParameter { | |||||
| OpParameter op_parameter_; | |||||
| int sub_graph_index_; | |||||
| } PartialParameter; | |||||
| #endif // MINDSPORE_NNACL_PARTIAL_FUSION_H_ | |||||
| @@ -54,7 +54,7 @@ using OpDataPtr = std::shared_ptr<OpData<T>>; | |||||
| template <typename T> | template <typename T> | ||||
| struct OpContext { | struct OpContext { | ||||
| uuids::uuid *sequential_num_; | uuids::uuid *sequential_num_; | ||||
| std::vector<OpDataPtr<T>> *outputData_; | |||||
| std::vector<OpDataPtr<T>> *output_data_; | |||||
| std::vector<Promise<int>> *results_; | std::vector<Promise<int>> *results_; | ||||
| const void *kernel_call_back_before_; | const void *kernel_call_back_before_; | ||||
| const void *kernel_call_back_after_; | const void *kernel_call_back_after_; | ||||
| @@ -97,14 +97,14 @@ class OpActor : public ActorBase { | |||||
| }; | }; | ||||
| template <typename T> | template <typename T> | ||||
| Future<std::list<int>> MindrtAsyncRun(const std::vector<OpDataPtr<T>> &inputData, OpContext<T> *context) { | |||||
| Future<std::list<int>> MindrtAsyncRun(const std::vector<OpDataPtr<T>> &input_data, OpContext<T> *context) { | |||||
| std::list<Future<int>> futures; | std::list<Future<int>> futures; | ||||
| for (auto promise : *(context->results_)) { | for (auto promise : *(context->results_)) { | ||||
| futures.push_back(promise.GetFuture()); | futures.push_back(promise.GetFuture()); | ||||
| } | } | ||||
| Future<std::list<int>> collect = mindspore::Collect<int>(futures); | Future<std::list<int>> collect = mindspore::Collect<int>(futures); | ||||
| for (auto data : inputData) { | |||||
| for (auto data : input_data) { | |||||
| Async(data->op_id_, &mindspore::OpActor<T>::RunOpData, data, context); | Async(data->op_id_, &mindspore::OpActor<T>::RunOpData, data, context); | ||||
| } | } | ||||
| @@ -112,18 +112,18 @@ Future<std::list<int>> MindrtAsyncRun(const std::vector<OpDataPtr<T>> &inputData | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| int MindrtRun(const std::vector<OpDataPtr<T>> &inputData, std::vector<OpDataPtr<T>> *outputData, | |||||
| int MindrtRun(const std::vector<OpDataPtr<T>> &input_data, std::vector<OpDataPtr<T>> *output_data, | |||||
| const void *kernel_call_back_before, const void *kernel_call_back_after) { | const void *kernel_call_back_before, const void *kernel_call_back_after) { | ||||
| OpContext<T> context; | OpContext<T> context; | ||||
| std::vector<Promise<int>> promises(outputData->size()); | |||||
| std::vector<Promise<int>> promises(output_data->size()); | |||||
| uuids::uuid uid; | uuids::uuid uid; | ||||
| context.sequential_num_ = &uid; | context.sequential_num_ = &uid; | ||||
| context.results_ = &promises; | context.results_ = &promises; | ||||
| context.outputData_ = outputData; | |||||
| context.output_data_ = output_data; | |||||
| context.kernel_call_back_before_ = kernel_call_back_before; | context.kernel_call_back_before_ = kernel_call_back_before; | ||||
| context.kernel_call_back_after_ = kernel_call_back_after; | context.kernel_call_back_after_ = kernel_call_back_after; | ||||
| auto collect = MindrtAsyncRun<T>(inputData, &context); | |||||
| auto collect = MindrtAsyncRun<T>(input_data, &context); | |||||
| collect.Wait(); | collect.Wait(); | ||||
| if (!collect.IsOK()) { | if (!collect.IsOK()) { | ||||
| return -1; | return -1; | ||||
| @@ -66,6 +66,22 @@ bool IsPartialNode(const void *primitive) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool IsCallNode(const void *primitive) { | |||||
| int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | |||||
| if (schema_version == SCHEMA_CUR) { | |||||
| return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_Call; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool IsSwitchNode(const void *primitive) { | |||||
| int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | |||||
| if (schema_version == SCHEMA_CUR) { | |||||
| return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_Switch; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| int GetPartialGraphIndex(const void *primitive) { | int GetPartialGraphIndex(const void *primitive) { | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| int index = -1; | int index = -1; | ||||
| @@ -24,6 +24,8 @@ const char *PrimitiveTypeName(int type); | |||||
| const char *PrimitiveCurVersionTypeName(int type); | const char *PrimitiveCurVersionTypeName(int type); | ||||
| int GenPrimVersionKey(int primitive_type, int schema_version); | int GenPrimVersionKey(int primitive_type, int schema_version); | ||||
| bool IsPartialNode(const void *primitive); | bool IsPartialNode(const void *primitive); | ||||
| bool IsCallNode(const void *node); | |||||
| bool IsSwitchNode(const void *node); | |||||
| int GetPartialGraphIndex(const void *primitive); | int GetPartialGraphIndex(const void *primitive); | ||||
| bool IsWhileNode(const void *primitive); | bool IsWhileNode(const void *primitive); | ||||
| int GetWhileBodySubgraphIndex(const void *primitive); | int GetWhileBodySubgraphIndex(const void *primitive); | ||||
| @@ -220,7 +220,7 @@ int CheckTensorsInvalid(const std::vector<Tensor *> &tensors) { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (tensor->data_type() != kObjectTypeTensorType && tensor->data_c() == nullptr) { | if (tensor->data_type() != kObjectTypeTensorType && tensor->data_c() == nullptr) { | ||||
| MS_LOG(ERROR) << "Graph input tensor is nullptr " << tensors; | |||||
| MS_LOG(ERROR) << "Graph input tensor data is nullptr " << tensor->tensor_name(); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto shape = tensor->shape(); | auto shape = tensor->shape(); | ||||
| @@ -188,23 +188,25 @@ void LiteKernel::FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope | |||||
| // clean io kernels | // clean io kernels | ||||
| this->in_kernels_.clear(); | this->in_kernels_.clear(); | ||||
| this->out_kernels_.clear(); | this->out_kernels_.clear(); | ||||
| // find io kernels | |||||
| for (auto *scope_kernel : scope_kernels) { | |||||
| if (scope_kernel == this) { | |||||
| continue; | |||||
| } | |||||
| for (auto *tensor : this->in_tensors_) { | |||||
| if (lite::IsContain(scope_kernel->out_tensors(), tensor)) { | |||||
| if (!lite::IsContain(this->in_kernels(), scope_kernel)) { | |||||
| this->AddInKernel(scope_kernel); | |||||
| } | |||||
| // find io kernels; TODO: optimize the time complexity of this scan | |||||
| for (auto *tensor : this->in_tensors_) { | |||||
| for (auto *scope_kernel : scope_kernels) { | |||||
| if (scope_kernel == this) { | |||||
| continue; | |||||
| } | |||||
| if (lite::IsContain(scope_kernel->out_tensors(), tensor) && !lite::IsContain(this->in_kernels(), scope_kernel)) { | |||||
| this->AddInKernel(scope_kernel); | |||||
| } | } | ||||
| } | } | ||||
| for (auto *tensor : this->out_tensors_) { | |||||
| if (lite::IsContain(scope_kernel->in_tensors(), tensor)) { | |||||
| if (!lite::IsContain(this->out_kernels(), scope_kernel)) { | |||||
| this->AddOutKernel(scope_kernel); | |||||
| } | |||||
| } | |||||
| for (auto *tensor : this->out_tensors_) { | |||||
| for (auto *scope_kernel : scope_kernels) { | |||||
| if (scope_kernel == this) { | |||||
| continue; | |||||
| } | |||||
| if (lite::IsContain(scope_kernel->in_tensors(), tensor) && !lite::IsContain(this->out_kernels(), scope_kernel)) { | |||||
| this->AddOutKernel(scope_kernel); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "src/lite_kernel_util.h" | #include "src/lite_kernel_util.h" | ||||
| #include <queue> | #include <queue> | ||||
| #include <set> | #include <set> | ||||
| #include "src/sub_graph_kernel.h" | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| @@ -187,4 +188,38 @@ void LiteKernelUtil::InitTensorInitRefCount(const std::vector<kernel::LiteKernel | |||||
| int LiteKernelUtil::SetInput(const LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs) { return -1; } | int LiteKernelUtil::SetInput(const LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs) { return -1; } | ||||
| bool LiteKernelUtil::IsSwitchCall(kernel::LiteKernel *kernel) { | |||||
| auto *subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel); | |||||
| if (subgraph_kernel == nullptr) { | |||||
| return false; | |||||
| } | |||||
| for (auto &node : subgraph_kernel->nodes()) { | |||||
| if (node->Type() == schema::PrimitiveType_Switch && | |||||
| InputsContainsSpecificNode(node, schema::PrimitiveType_PartialFusion) && node->out_kernels().size() == 1 && | |||||
| node->out_kernels().front()->Type() == schema::PrimitiveType_Call) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| kernel::LiteKernel *LiteKernelUtil::GetInputsSpecificNode(const kernel::LiteKernel *kernel, | |||||
| const schema::PrimitiveType &primitive_type) { | |||||
| for (auto input : kernel->in_kernels()) { | |||||
| if (input->Type() == primitive_type) { | |||||
| return input; | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| bool LiteKernelUtil::InputsContainsSpecificNode(const kernel::LiteKernel *kernel, | |||||
| const schema::PrimitiveType &primitive_type) { | |||||
| if (GetInputsSpecificNode(kernel, primitive_type)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -35,6 +35,13 @@ class LiteKernelUtil { | |||||
| static void InitTensorInitRefCount(const std::vector<kernel::LiteKernel *> &kernels); | static void InitTensorInitRefCount(const std::vector<kernel::LiteKernel *> &kernels); | ||||
| static int SetInput(const LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs); | static int SetInput(const LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs); | ||||
| static bool IsSwitchCall(kernel::LiteKernel *kernel); | |||||
| static kernel::LiteKernel *GetInputsSpecificNode(const kernel::LiteKernel *kernel, | |||||
| const schema::PrimitiveType &primitive_type); | |||||
| static bool InputsContainsSpecificNode(const kernel::LiteKernel *kernel, const schema::PrimitiveType &primitive_type); | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -48,7 +48,7 @@ int LiteOpActor::CompileArrow() { | |||||
| void LiteOpActor::AsyncOutput(OpContext<Tensor> *context) { | void LiteOpActor::AsyncOutput(OpContext<Tensor> *context) { | ||||
| for (auto op_arrow : output_op_arrows_) { | for (auto op_arrow : output_op_arrows_) { | ||||
| auto data = context->outputData_->at(op_arrow->from_output_index_); | |||||
| auto data = context->output_data_->at(op_arrow->from_output_index_); | |||||
| Async(op_arrow->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data, context); | Async(op_arrow->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data, context); | ||||
| } | } | ||||
| return; | return; | ||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/populate/populate_register.h" | |||||
| using mindspore::schema::PrimitiveType_Call; | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| OpParameter *PopulateCallParameter(const void *prim) { | |||||
| OpParameter *call_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (call_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc CallParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(call_parameter, 0, sizeof(OpParameter)); | |||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | |||||
| call_parameter->type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(call_parameter); | |||||
| } | |||||
| REG_POPULATE(PrimitiveType_Call, PopulateCallParameter, SCHEMA_CUR) | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -14,14 +14,11 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/partial_fusion_parameter.h" | |||||
| using mindspore::schema::PrimitiveType_PartialFusion; | using mindspore::schema::PrimitiveType_PartialFusion; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| typedef struct PartialParameter { | |||||
| OpParameter op_parameter_; | |||||
| int sub_graph_index_; | |||||
| } PartialParameter; | |||||
| OpParameter *PopulatePartialParameter(const void *prim) { | OpParameter *PopulatePartialParameter(const void *prim) { | ||||
| PartialParameter *partial_parameter = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter))); | PartialParameter *partial_parameter = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter))); | ||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/base/call.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/tensorlist.h" | |||||
| #include "src/common/utils.h" | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_Call; | |||||
| // this file will be unnecessary once actor creation is moved before scheduling. | |||||
| namespace mindspore::kernel { | |||||
| int CallCPUKernel::Init() { return RET_OK; } | |||||
| int CallCPUKernel::ReSize() { return RET_OK; } | |||||
| int CallCPUKernel::Run() { return RET_OK; } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Call, LiteKernelCreator<CallCPUKernel>) | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Call, LiteKernelCreator<CallCPUKernel>) | |||||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Call, LiteKernelCreator<CallCPUKernel>) | |||||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Call, LiteKernelCreator<CallCPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CALL_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CALL_H_ | |||||
| #include <vector> | |||||
| #include "src/runtime/kernel/arm/base/carry_data.h" | |||||
| #include "src/tensor.h" | |||||
| #include "src/tensorlist.h" | |||||
| // this file will be unnecessary once actor creation is moved before scheduling. | |||||
| namespace mindspore::kernel { | |||||
| class CallCPUKernel : public LiteKernel { | |||||
| public: | |||||
| CallCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx) {} | |||||
| ~CallCPUKernel() override = default; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CALL_H_ | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/base/partial_fusion.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/tensorlist.h" | |||||
| #include "src/common/utils.h" | |||||
| // this file is going to be removed once actor creation is moved before scheduling. | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_PartialFusion; | |||||
| namespace mindspore::kernel { | |||||
| int PartialFusionKernel::Init() { return RET_OK; } | |||||
| int PartialFusionKernel::ReSize() { return RET_OK; } | |||||
| int PartialFusionKernel::Run() { return RET_OK; } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_PartialFusion, LiteKernelCreator<PartialFusionKernel>) | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_PartialFusion, LiteKernelCreator<PartialFusionKernel>) | |||||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_PartialFusion, LiteKernelCreator<PartialFusionKernel>) | |||||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_PartialFusion, LiteKernelCreator<PartialFusionKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_PARTIAL_FUSION_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_PARTIAL_FUSION_H_ | |||||
| #include <vector> | |||||
| #include "src/runtime/kernel/arm/base/carry_data.h" | |||||
| #include "src/tensor.h" | |||||
| #include "src/tensorlist.h" | |||||
| // this file is going to be removed once actor creation is moved before scheduling. | |||||
| namespace mindspore::kernel { | |||||
| class PartialFusionKernel : public LiteKernel { | |||||
| public: | |||||
| PartialFusionKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx) {} | |||||
| ~PartialFusionKernel() override = default; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_PARTIAL_FUSION_H_ | |||||
| @@ -72,7 +72,8 @@ int ArithmeticCPUKernel::CheckDataType() { | |||||
| auto in0_dataType = in_tensors_.at(0)->data_type(); | auto in0_dataType = in_tensors_.at(0)->data_type(); | ||||
| auto in1_dataType = in_tensors_.at(1)->data_type(); | auto in1_dataType = in_tensors_.at(1)->data_type(); | ||||
| if (in0_dataType != in1_dataType) { | if (in0_dataType != in1_dataType) { | ||||
| MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same."; | |||||
| MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same. input 0 dataType: " | |||||
| << in0_dataType << " input 1 dataType: " << in1_dataType; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -408,7 +409,7 @@ int ArithmeticsRun(void *cdata, int task_id) { | |||||
| int ArithmeticCPUKernel::Run() { | int ArithmeticCPUKernel::Run() { | ||||
| if (CheckDataType() != RET_OK) { | if (CheckDataType() != RET_OK) { | ||||
| MS_LOG(ERROR) << "ArithmeticCPUKernel check dataType failed."; | |||||
| MS_LOG(ERROR) << "ArithmeticCPUKernel check dataType failed, kernel name: " << this->name(); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (!input0_broadcast_) { | if (!input0_broadcast_) { | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "src/common/version_manager.h" | #include "src/common/version_manager.h" | ||||
| #include "src/runtime/infer_manager.h" | #include "src/runtime/infer_manager.h" | ||||
| #include "src/common/tensor_util.h" | #include "src/common/tensor_util.h" | ||||
| #include "src/common/utils.h" | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| @@ -141,6 +142,11 @@ void SubGraphKernel::InitOutTensorInitRefCount() { | |||||
| node->InitOutTensorInitRefCount(); | node->InitOutTensorInitRefCount(); | ||||
| } | } | ||||
| } | } | ||||
| void SubGraphKernel::DropNode(LiteKernel *node) { | |||||
| lite::VectorErase(&nodes_, node); | |||||
| lite::VectorErase(&in_nodes_, node); | |||||
| lite::VectorErase(&out_nodes_, node); | |||||
| } | |||||
| int CpuSubGraph::Prepare() { | int CpuSubGraph::Prepare() { | ||||
| auto ret = SubGraphKernel::Prepare(); | auto ret = SubGraphKernel::Prepare(); | ||||
| @@ -112,6 +112,8 @@ class SubGraphKernel : public LiteKernel { | |||||
| std::vector<LiteKernel *> nodes() { return this->nodes_; } | std::vector<LiteKernel *> nodes() { return this->nodes_; } | ||||
| void DropNode(LiteKernel *node); | |||||
| protected: | protected: | ||||
| std::vector<LiteKernel *> nodes_{}; | std::vector<LiteKernel *> nodes_{}; | ||||
| // entry nodes in nodes | // entry nodes in nodes | ||||
| @@ -348,6 +348,8 @@ void *Tensor::MutableData() { | |||||
| return this->data_; | return this->data_; | ||||
| } | } | ||||
| void Tensor::IncRefCount() { ++ref_count_; } | |||||
| void Tensor::DecRefCount() { | void Tensor::DecRefCount() { | ||||
| if (this->IsConst() || this->IsGraphInput()) { | if (this->IsConst() || this->IsGraphInput()) { | ||||
| return; | return; | ||||
| @@ -145,6 +145,8 @@ class Tensor : public mindspore::tensor::MSTensor { | |||||
| void ResetRefCount() { this->ref_count_ = this->init_ref_count_; } | void ResetRefCount() { this->ref_count_ = this->init_ref_count_; } | ||||
| void IncRefCount(); | |||||
| void DecRefCount(); | void DecRefCount(); | ||||
| std::string ToString() const; | std::string ToString() const; | ||||