| @@ -109,7 +109,7 @@ checkopts() | |||||
| ENABLE_GPU="off" | ENABLE_GPU="off" | ||||
| # Process the options | # Process the options | ||||
| while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:En' opt | |||||
| while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:EnT:' opt | |||||
| do | do | ||||
| OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') | OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') | ||||
| case "${opt}" in | case "${opt}" in | ||||
| @@ -282,6 +282,11 @@ checkopts() | |||||
| ENABLE_IBVERBS="on" | ENABLE_IBVERBS="on" | ||||
| echo "enable IBVERBS for parameter server" | echo "enable IBVERBS for parameter server" | ||||
| ;; | ;; | ||||
| T) | |||||
| check_on_off $OPTARG T | |||||
| SUPPORT_TRAIN=$OPTARG | |||||
| echo "support train on device " | |||||
| ;; | |||||
| *) | *) | ||||
| echo "Unknown option ${opt}!" | echo "Unknown option ${opt}!" | ||||
| usage | usage | ||||
| @@ -23,6 +23,7 @@ endif() | |||||
| if (SUPPORT_TRAIN) | if (SUPPORT_TRAIN) | ||||
| set(ANF_SRC | set(ANF_SRC | ||||
| ${ANF_SRC} | |||||
| # ${CCSRC_DIR}/common/trans.cc | # ${CCSRC_DIR}/common/trans.cc | ||||
| # ${CCSRC_DIR}/utils/lite/base_ref_utils.cc | # ${CCSRC_DIR}/utils/lite/base_ref_utils.cc | ||||
| # ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc | # ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc | ||||
| @@ -40,14 +41,17 @@ if (SUPPORT_TRAIN) | |||||
| set(LITE_SRC | set(LITE_SRC | ||||
| ${LITE_SRC} | ${LITE_SRC} | ||||
| ${ANF_SRC} | ${ANF_SRC} | ||||
| ${PASS_SRC} | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc | |||||
| # ${PASS_SRC} | |||||
| # ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc | |||||
| # ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc | |||||
| # ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc | |||||
| # ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc | |||||
| # ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc | |||||
| # ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc # temporary | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc # temporary | |||||
| ) | ) | ||||
| else () | else () | ||||
| set(LITE_SRC | set(LITE_SRC | ||||
| ${LITE_SRC} | ${LITE_SRC} | ||||
| @@ -27,96 +27,143 @@ | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| void AnfImporterFromMetaGraph::ConverterConstTensor() { | void AnfImporterFromMetaGraph::ConverterConstTensor() { | ||||
| MS_EXCEPTION_IF_NULL(model); | |||||
| auto *meta_graph = model->GetMetaGraph(); | |||||
| MS_EXCEPTION_IF_NULL(model_); | |||||
| auto *meta_graph = model_->GetMetaGraph(); | |||||
| MS_EXCEPTION_IF_NULL(meta_graph); | MS_EXCEPTION_IF_NULL(meta_graph); | ||||
| for (size_t i = 0; i < meta_graph->allTensors()->size(); i++) { | |||||
| num_of_tensors_ = meta_graph->allTensors()->size(); | |||||
| for (size_t i = 0; i < num_of_tensors_; i++) { | |||||
| auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i); | auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i); | ||||
| MS_EXCEPTION_IF_NULL(tensor); | MS_EXCEPTION_IF_NULL(tensor); | ||||
| if (tensor->nodeType() != schema::NodeType_ValueNode) { | |||||
| if ((tensor->nodeType() != schema::NodeType_ValueNode) && (tensor->nodeType() != schema::NodeType_Parameter)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| MS_ASSERT(tensor->dims() != nullptr); | MS_ASSERT(tensor->dims() != nullptr); | ||||
| auto parameter = model->add_parameter(); | |||||
| auto parameter = model_->add_parameter(); | |||||
| std::vector<int> shape; | std::vector<int> shape; | ||||
| for (size_t j = 0; j < tensor->dims()->size(); ++j) { | for (size_t j = 0; j < tensor->dims()->size(); ++j) { | ||||
| shape.push_back(tensor->dims()->data()[j]); | shape.push_back(tensor->dims()->data()[j]); | ||||
| } | } | ||||
| auto type_id = static_cast<TypeId>(tensor->dataType()); | |||||
| auto type_id = static_cast<TypeId>(tensor->dataType()); // todo: check error | |||||
| auto type_ptr = TypeIdToType(type_id); | auto type_ptr = TypeIdToType(type_id); | ||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| auto abstractBase = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| // XXX TODO copy format | |||||
| parameter->set_abstract(abstractBase); | |||||
| parameter->set_name(std::string("Parameter")); | |||||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||||
| MS_EXCEPTION_IF_NULL(param_value); | |||||
| param_value->set_tensor_shape(shape); | |||||
| param_value->set_tensor_type(type_id); | |||||
| if (tensor->data() != nullptr) { | |||||
| auto size = tensor->data()->size(); | |||||
| char *tensor_data = new char[size](); | |||||
| std::memcpy(tensor_data, tensor->data()->data(), size); | |||||
| MS_EXCEPTION_IF_NULL(tensor_data); | |||||
| param_value->set_tensor_addr(tensor_data); | |||||
| param_value->set_tensor_size(size); | |||||
| if (tensor->nodeType() == schema::NodeType_ValueNode) { | |||||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||||
| MS_EXCEPTION_IF_NULL(param_value); | |||||
| param_value->set_tensor_shape(shape); | |||||
| param_value->set_tensor_type(type_id); | |||||
| if (tensor->data() != nullptr) { | |||||
| auto size = tensor->data()->size(); | |||||
| char *tensor_data = new char[size](); | |||||
| std::memcpy(tensor_data, tensor->data()->data(), size); | |||||
| MS_EXCEPTION_IF_NULL(tensor_data); | |||||
| param_value->set_tensor_addr(tensor_data); | |||||
| param_value->set_tensor_size(size); | |||||
| } | |||||
| parameter->set_default_param(param_value); | |||||
| } | } | ||||
| parameter->set_default_param(param_value); | |||||
| AddNode(i, parameter); | AddNode(i, parameter); | ||||
| model_->AddAnfNode(i, parameter); | |||||
| } | } | ||||
| } | } | ||||
| int AnfImporterFromMetaGraph::ConverterCNode() { | int AnfImporterFromMetaGraph::ConverterCNode() { | ||||
| MS_EXCEPTION_IF_NULL(model); | |||||
| auto *meta_graph = model->GetMetaGraph(); | |||||
| MS_EXCEPTION_IF_NULL(model_); | |||||
| auto *meta_graph = model_->GetMetaGraph(); | |||||
| MS_EXCEPTION_IF_NULL(meta_graph); | MS_EXCEPTION_IF_NULL(meta_graph); | ||||
| auto cNodes = meta_graph->nodes(); | |||||
| for (size_t i = 0; i < cNodes->size(); i++) { | |||||
| auto cNode = cNodes->GetAs<schema::CNode>(i); | |||||
| MS_EXCEPTION_IF_NULL(cNode); | |||||
| auto tensor_id = cNode->outputIndex()->data()[0]; | |||||
| if (GetNode(tensor_id)) { | |||||
| continue; | |||||
| } | |||||
| auto prim = std::make_shared<PrimitiveValue>(model->GetOp(cNode->name()->str())); | |||||
| // Crate CNode -- Order of inputs is as follows | |||||
| // First input should be the Primitive | |||||
| // Then we have CNodes that contribute to this CNode | |||||
| // Finally we Have the parameters | |||||
| // first itteration -- create CNode with primitive, create originator map | |||||
| for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { | |||||
| auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i); | |||||
| MS_EXCEPTION_IF_NULL(cNode); | |||||
| auto prim = std::make_shared<PrimitiveValue>(model_->GetOp(cNode->name()->str())); | |||||
| if (prim == nullptr) { | if (prim == nullptr) { | ||||
| MS_LOG(ERROR) << "th tensorDef in subGraphDef is nullptr"; | MS_LOG(ERROR) << "th tensorDef in subGraphDef is nullptr"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto value_node = NewValueNode(prim); | auto value_node = NewValueNode(prim); | ||||
| AddNode(tensor_id, value_node); | |||||
| // auto prim_name = std::string("PrimitivePy: ") + std::string(cNode->name()->c_str()); | |||||
| // value_node->set_fullname_with_scope(prim_name); | |||||
| std::vector<AnfNodePtr> op_inputs = {value_node}; | std::vector<AnfNodePtr> op_inputs = {value_node}; | ||||
| auto cnode = model_->NewCNode(op_inputs); | |||||
| auto node_name = std::string(cNode->name()->c_str()) + std::to_string(i); | |||||
| cnode->set_fullname_with_scope(node_name); | |||||
| AddNode(num_of_tensors_ + i, cnode); | |||||
| for (size_t j = 0; j < cNode->outputIndex()->size(); j++) { | |||||
| int tensor_id = cNode->outputIndex()->data()[j]; | |||||
| originator_[tensor_id] = cnode; | |||||
| } | |||||
| } | |||||
| // second itteration -- fill in input CNodes and Parameters | |||||
| // populate map | |||||
| for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { | |||||
| std::vector<int> input; | |||||
| std::vector<int> output; | |||||
| int tensor_id; | |||||
| auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i); | |||||
| MS_EXCEPTION_IF_NULL(cNode); | |||||
| auto cnode = std::dynamic_pointer_cast<CNode>(GetNode(num_of_tensors_ + i)); | |||||
| for (size_t j = 0; j < cNode->outputIndex()->size(); j++) { | |||||
| tensor_id = cNode->outputIndex()->data()[j]; | |||||
| output.push_back(tensor_id); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(cNode->inputIndex()); | MS_EXCEPTION_IF_NULL(cNode->inputIndex()); | ||||
| for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { | for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { | ||||
| auto node = GetNode(*(cNode->inputIndex()->GetAs<uint32_t>(j))); | |||||
| if (nullptr == node) { | |||||
| MS_LOG(ERROR) << "Can't find input node."; | |||||
| return RET_ERROR; | |||||
| tensor_id = cNode->inputIndex()->data()[j]; | |||||
| input.push_back(tensor_id); | |||||
| auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(tensor_id); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| if ((tensor->nodeType() == schema::NodeType_Parameter) && (originator_[tensor_id] != nullptr)) { | |||||
| cnode->add_input(originator_[tensor_id]); | |||||
| } | } | ||||
| // todo: CheckInputNodeType, the first node should be op; | |||||
| op_inputs.push_back(node); | |||||
| } | } | ||||
| auto cnode = model->NewCNode(op_inputs); | |||||
| auto node_name = std::string(cNode->name()->c_str()); | |||||
| cnode->set_fullname_with_scope(node_name); | |||||
| AddNode(tensor_id, cnode); | |||||
| // finally add all the Parameters (which are ValueNodes) | |||||
| for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { | |||||
| tensor_id = cNode->inputIndex()->data()[j]; | |||||
| auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(tensor_id); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| if ((tensor->nodeType() == schema::NodeType_ValueNode) && (GetNode(tensor_id) != nullptr)) { | |||||
| cnode->add_input(GetNode(tensor_id)); | |||||
| } | |||||
| } | |||||
| model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void AnfImporterFromMetaGraph::AddReturnCNode() { | void AnfImporterFromMetaGraph::AddReturnCNode() { | ||||
| MS_EXCEPTION_IF_NULL(model); | |||||
| auto *meta_graph = model->GetMetaGraph(); | |||||
| MS_EXCEPTION_IF_NULL(model_); | |||||
| auto *meta_graph = model_->GetMetaGraph(); | |||||
| MS_EXCEPTION_IF_NULL(meta_graph); | MS_EXCEPTION_IF_NULL(meta_graph); | ||||
| std::vector<int> input; | |||||
| std::vector<int> output; | |||||
| std::vector<AnfNodePtr> op_inputs; | std::vector<AnfNodePtr> op_inputs; | ||||
| auto value_node = NewValueNode(prim::kPrimReturn); | auto value_node = NewValueNode(prim::kPrimReturn); | ||||
| // value_node->set_fullname_with_scope("Primitive"); | |||||
| op_inputs.push_back(value_node); | op_inputs.push_back(value_node); | ||||
| auto tensor_id = meta_graph->outputIndex()->data()[0]; | |||||
| op_inputs.push_back(GetNode(tensor_id)); | |||||
| auto cnode = model->NewCNode(op_inputs); | |||||
| for (int i = 0; i < meta_graph->outputIndex()->size(); i++) { | |||||
| auto prev_cnode = originator_[meta_graph->outputIndex()->data()[i]]; | |||||
| if (prev_cnode != nullptr) op_inputs.push_back(prev_cnode); | |||||
| input.push_back(meta_graph->outputIndex()->data()[i]); | |||||
| } | |||||
| auto cnode = model_->NewCNode(op_inputs); | |||||
| cnode->set_fullname_with_scope("return"); | cnode->set_fullname_with_scope("return"); | ||||
| model->set_return(cnode); | |||||
| model_->set_return(cnode); | |||||
| model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output); | |||||
| } | } | ||||
| FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model; } | |||||
| FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model_; } | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ | #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "src/train/model_impl.h" | #include "src/train/model_impl.h" | ||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/common/anf_importer/anf_importer.h" | #include "src/common/anf_importer/anf_importer.h" | ||||
| @@ -25,7 +26,7 @@ | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfImporterFromMetaGraph : public AnfImporter { | class AnfImporterFromMetaGraph : public AnfImporter { | ||||
| public: | public: | ||||
| explicit AnfImporterFromMetaGraph(std::shared_ptr<ModelImpl> model) : model(model) {} | |||||
| explicit AnfImporterFromMetaGraph(std::shared_ptr<ModelImpl> model) : model_(model) {} | |||||
| ~AnfImporterFromMetaGraph() override = default; | ~AnfImporterFromMetaGraph() override = default; | ||||
| @@ -39,9 +40,10 @@ class AnfImporterFromMetaGraph : public AnfImporter { | |||||
| void AddReturnCNode() override; | void AddReturnCNode() override; | ||||
| private: | private: | ||||
| std::shared_ptr<ModelImpl> model = nullptr; | |||||
| std::shared_ptr<ModelImpl> model_ = nullptr; | |||||
| std::map<int, AnfNodePtr> originator_; | |||||
| int num_of_tensors_ = 0; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ | #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ | ||||
| @@ -60,7 +60,7 @@ class LiteKernel { | |||||
| explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | ||||
| const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx, | const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx, | ||||
| const lite::Primitive *primitive) | const lite::Primitive *primitive) | ||||
| : opParameter(parameter), inputs_(inputs), outputs_(outputs), train_mode(false), primitive_(primitive), | |||||
| : opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive), | |||||
| context_(ctx) { | context_(ctx) { | ||||
| this->in_kernel_.clear(); | this->in_kernel_.clear(); | ||||
| this->out_kernel_.clear(); | this->out_kernel_.clear(); | ||||
| @@ -136,7 +136,7 @@ class LiteKernel { | |||||
| std::vector<lite::tensor::Tensor *> outputs_; | std::vector<lite::tensor::Tensor *> outputs_; | ||||
| std::vector<LiteKernel *> in_kernel_; | std::vector<LiteKernel *> in_kernel_; | ||||
| std::vector<LiteKernel *> out_kernel_; | std::vector<LiteKernel *> out_kernel_; | ||||
| bool train_mode; | |||||
| bool train_mode = false; | |||||
| bool need_reinit = false; | bool need_reinit = false; | ||||
| }; | }; | ||||
| @@ -14,11 +14,11 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifdef SUPPORT_TRAIN | |||||
| #include "src/train/model_impl.h" | |||||
| #else | |||||
| // #ifdef SUPPORT_TRAIN | |||||
| // #include "src/train/model_impl.h" | |||||
| // #else | |||||
| #include "src/model_impl.h" | #include "src/model_impl.h" | ||||
| #endif | |||||
| // #endif | |||||
| #include "include/model.h" | #include "include/model.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -10,6 +10,13 @@ file(GLOB KERNEL_SRC | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc | ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc | ||||
| ) | ) | ||||
| if (SUPPORT_TRAIN) | |||||
| file (GLOB TRAIN_KERNEL_SRC | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/fp32_grad/*.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/nnacl/fp32_grad/*.cc | |||||
| ) | |||||
| endif() | |||||
| if (PLATFORM_ARM64) | if (PLATFORM_ARM64) | ||||
| # assembly | # assembly | ||||
| file(GLOB ASSEMBLY_SRC nnacl/assembly/arm64/*.s | file(GLOB ASSEMBLY_SRC nnacl/assembly/arm64/*.s | ||||
| @@ -27,5 +34,5 @@ if (PLATFORM_ARM32) | |||||
| set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) | set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) | ||||
| endif() | endif() | ||||
| add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC}) | |||||
| add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC} ${TRAIN_KERNEL_SRC}) | |||||
| add_subdirectory(nnacl) | add_subdirectory(nnacl) | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/fp32/activation_grad.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/activation_grad.h" | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| @@ -102,6 +102,8 @@ kernel::LiteKernel *CpuActivationGradFp32KernelCreator(const std::vector<lite::t | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "InferShape kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "InferShape kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | } | ||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -48,4 +48,4 @@ class ActivationGradCPUKernel : public LiteKernel { | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_ | |||||
| @@ -16,9 +16,9 @@ | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h" | |||||
| #include "src/runtime/kernel/arm/fp32/arithmetic_grad.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h" | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -88,4 +88,4 @@ class ArithmeticGradCPUKernel : public LiteKernel { | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ARITHMETIC_GRAD_H_ | |||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/runtime/kernel/arm/fp32/bias_grad.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/bias_grad.h" | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| @@ -18,8 +18,8 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_factory.h" | #include "src/kernel_factory.h" | ||||
| #include "src/runtime/kernel/arm/fp32/bngrad_input.h" | |||||
| #include "src/runtime//kernel/arm/nnacl/batch_norm.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/bn_grad.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h" | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -54,10 +54,6 @@ int BNGradInputCPUKernel::Init() { | |||||
| int BNGradInputCPUKernel::ReSize() { return RET_OK; } | int BNGradInputCPUKernel::ReSize() { return RET_OK; } | ||||
| /* | |||||
| according to https://wiseodd.github.io/techblog/2016/07/04/batchnorm | |||||
| */ | |||||
| int BNGradInputCPUKernel::Run() { | int BNGradInputCPUKernel::Run() { | ||||
| // std::cout << "run succ" << std::endl; | // std::cout << "run succ" << std::endl; | ||||
| auto *input_x = inputs_.at(0); | auto *input_x = inputs_.at(0); | ||||
| @@ -107,6 +103,8 @@ kernel::LiteKernel *CpuBNGradInputFp32KernelCreator(const std::vector<lite::tens | |||||
| if (RET_OK != ret) { | if (RET_OK != ret) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | ||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | } | ||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -39,4 +39,4 @@ class BNGradInputCPUKernel : public LiteKernel { | |||||
| int workspace_size; | int workspace_size; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_ | |||||
| @@ -14,11 +14,11 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/fp32/convolution_grad_filter.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h" | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/pack.h" | #include "src/runtime/kernel/arm/nnacl/pack.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/pack_ext.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h" | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -39,4 +39,4 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel { | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ | |||||
| @@ -14,11 +14,11 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/fp32/convolution_grad_input.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h" | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/pack.h" | #include "src/runtime/kernel/arm/nnacl/pack.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/pack_ext.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h" | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -39,4 +39,4 @@ class ConvolutionGradInputCPUKernel : public LiteKernel { | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H | |||||
| @@ -17,7 +17,7 @@ | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/fp32/opt_momentum.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/opt_momentum.h" | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -14,11 +14,11 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/fp32/pooling_grad.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/pooling_grad.h" | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" | #include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h" | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -48,4 +48,4 @@ class PoolingGradCPUKernel : public LiteKernel { | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_ | |||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/fp32/power_grad.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/power_grad.h" | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -47,4 +47,4 @@ class PowerGradCPUKernel : public LiteKernel { | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_ | |||||
| @@ -14,13 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/softmax_parameter.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h" | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| @@ -73,7 +72,7 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *la | |||||
| int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | ||||
| auto ins = reinterpret_cast<float *>(inputs_.at(0)->Data()); | auto ins = reinterpret_cast<float *>(inputs_.at(0)->Data()); | ||||
| auto labels = reinterpret_cast<int *>(inputs_.at(1)->Data()); | auto labels = reinterpret_cast<int *>(inputs_.at(1)->Data()); | ||||
| auto out = reinterpret_cast<float *>(outputs_.at(0)->Data()); | |||||
| auto out = reinterpret_cast<float *>(outputs_.at(1)->Data()); | |||||
| float *grads = NULL; | float *grads = NULL; | ||||
| if (is_train()) { // outputs_.size() > 1) | if (is_train()) { // outputs_.size() > 1) | ||||
| grads = reinterpret_cast<float *>(outputs_.at(0)->Data()); | grads = reinterpret_cast<float *>(outputs_.at(0)->Data()); | ||||
| @@ -90,10 +89,11 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | |||||
| SoftmaxParameter sm_params; | SoftmaxParameter sm_params; | ||||
| sm_params.n_dim_ = param->n_dim_; | sm_params.n_dim_ = param->n_dim_; | ||||
| sm_params.element_size_ = data_size; | sm_params.element_size_ = data_size; | ||||
| sm_params.axis_ = 1; | |||||
| sm_params.axis_ = 0; | |||||
| for (int i = 0; i < 4; i++) // softmax has only 4 params in shape | for (int i = 0; i < 4; i++) // softmax has only 4 params in shape | ||||
| sm_params.input_shape_[i] = param->input_shape_[i]; | sm_params.input_shape_[i] = param->input_shape_[i]; | ||||
| float sum_data[sm_params.input_shape_[sm_params.axis_]]; | |||||
| float sum_data[sm_params.input_shape_[sm_params.axis_]] = {0}; | |||||
| std::fill(sum_data, sum_data + sm_params.input_shape_[sm_params.axis_], 0); | |||||
| Softmax(ins, losses, sum_data, &sm_params); | Softmax(ins, losses, sum_data, &sm_params); | ||||
| if (is_train()) { | if (is_train()) { | ||||
| @@ -20,7 +20,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/fp32/softmax_grad.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/softmax_grad.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" | #include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| @@ -30,8 +30,7 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LiteKernel { | |||||
| explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter, | explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter, | ||||
| const std::vector<lite::tensor::Tensor *> &inputs, | const std::vector<lite::tensor::Tensor *> &inputs, | ||||
| const std::vector<lite::tensor::Tensor *> &outputs, | const std::vector<lite::tensor::Tensor *> &outputs, | ||||
| const lite::Context *ctx, | |||||
| const lite::Primitive *primitive) | |||||
| const lite::Context *ctx, const lite::Primitive *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | ||||
| param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter); | param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter); | ||||
| } | } | ||||
| @@ -0,0 +1,88 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||||
| #include <math.h> | |||||
| #include "src/runtime/kernel/arm/opclib/op_base.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" | |||||
| #include "src/runtime/kernel/arm/opclib/errorcode.h" | |||||
| struct ActivationGradParameter { | |||||
| OpParameter op_parameter{}; | |||||
| int type_; | |||||
| float alpha_{0.01}; | |||||
| }; | |||||
| inline int ReluGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| dst[i] = src1[i] > 0 ? 1.0f : 0.0f; | |||||
| } | |||||
| ElementMul(src0, dst, dst, length); | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int Relu6Grad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| if (src1[i] < 0) { | |||||
| dst[i] = 0; | |||||
| } else { | |||||
| dst[i] = src1[i] > 6.0f ? 0.0f : 1.0f; | |||||
| } | |||||
| } | |||||
| ElementMul(src0, dst, dst, length); | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int LReluGrad(float *src0, float *src1, int length, float *dst, float alpha) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| dst[i] = src1[i] > 0.0f ? 1.0f : alpha; | |||||
| } | |||||
| ElementMul(src0, dst, dst, length); | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int SigmoidGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); | |||||
| } | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int TanhGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i]; | |||||
| } | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int HSwishGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f)); | |||||
| dst[i] = tmp * src0[i]; | |||||
| } | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int HSigmoidGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f)); | |||||
| dst[i] = tmp * src0[i]; | |||||
| } | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/nnacl/fp32/arithmetic_grad.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/arithmetic_grad.h" | |||||
| void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) { | void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) { | ||||
| for (int i = 0; i < element_size; i++) { | for (int i = 0; i < element_size; i++) { | ||||
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARITHMETIC_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ARITHMETIC_GRAD_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_ | |||||
| void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size); | void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size); | ||||
| void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size); | void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size); | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include "src/runtime/kernel/arm/nnacl/batch_norm.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h" | |||||
| static void sumSpatialBatch(const float *in, int size, int ch, float *out) { | static void sumSpatialBatch(const float *in, int size, int ch, float *out) { | ||||
| std::fill(out, out + ch, 0.f); | std::fill(out, out + ch, 0.f); | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_ | |||||
| #define MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_ | |||||
| struct bnParameter { | struct bnParameter { | ||||
| int batch; | int batch; | ||||
| @@ -36,4 +36,5 @@ void meanAdd(const float *x, const float *mean, const float *variance_delta, int | |||||
| void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta, | void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta, | ||||
| const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta); | const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta); | ||||
| #endif | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_BATCH_NORM_H_ | |||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/nnacl/fp32/gemm.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/gemm.h" | |||||
| static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c, | static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c, | ||||
| int ldc) { | int ldc) { | ||||
| @@ -14,10 +14,10 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_ | |||||
| void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, | void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, | ||||
| int ldb, float beta, float *mat_c, int ldc); | int ldb, float beta, float *mat_c, int ldc); | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GEMM_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_GEMM_H_ | |||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include <string.h> | #include <string.h> | ||||
| #include "src/runtime/kernel/arm/nnacl/pack_ext.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/pack_ext.h" | |||||
| static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); } | static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); } | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <cstdint> | #include <cstdint> | ||||
| #include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h" | |||||
| void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) { | void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) { | ||||
| int stride_w = pooling_param->stride_w_; | int stride_w = pooling_param->stride_w_; | ||||
| @@ -14,12 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_ | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" | #include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" | ||||
| void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param); | void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param); | ||||
| void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param); | void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param); | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_POOLING_GRAD_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_GRAD_POOLING_GRAD_H_ | |||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include <cstddef> | #include <cstddef> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce_grad.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32_grad/reduce_grad.h" | |||||
| static inline bool NextIndex(const int num_dims, const int *dims, int *current) { | static inline bool NextIndex(const int num_dims, const int *dims, int *current) { | ||||
| int carry = 1; | int carry = 1; | ||||
| @@ -57,4 +57,3 @@ std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefTo | |||||
| return multiTensor; | return multiTensor; | ||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,16 +16,15 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "base/base_ref.h" | |||||
| #include "utils/base_ref.h" | |||||
| #include "include/ms_tensor.h" | #include "include/ms_tensor.h" | ||||
| #ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H | |||||
| #define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H | |||||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ | |||||
| #define MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| std::vector<std::shared_ptr<tensor::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref); | std::vector<std::shared_ptr<tensor::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref); | ||||
| std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefToMultiTensor( | std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefToMultiTensor( | ||||
| const VectorRef &vector_ref); | const VectorRef &vector_ref); | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H | |||||
| #endif // MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ | |||||
| @@ -14,7 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "mindspore/lite/src/train/lite_kernel_runtime.h" | |||||
| #include "src/train/lite_kernel_runtime.h" | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| std::vector<CNodePtr> LiteInferKernelRuntime::GetGraphInputs(const std::vector<CNodePtr> &execution_order) { | std::vector<CNodePtr> LiteInferKernelRuntime::GetGraphInputs(const std::vector<CNodePtr> &execution_order) { | ||||
| std::vector<CNodePtr> graph_inputs; | std::vector<CNodePtr> graph_inputs; | ||||
| @@ -34,7 +35,8 @@ std::vector<CNodePtr> LiteInferKernelRuntime::GetGraphInputs(const std::vector<C | |||||
| } | } | ||||
| void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph, | void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph, | ||||
| const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) { | |||||
| const std::vector<tensor::Tensor *> &inputs, | |||||
| std::vector<tensor::Tensor *> *outputs) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto execution_order = graph->execution_order(); | auto execution_order = graph->execution_order(); | ||||
| auto graph_inputs = GetGraphInputs(execution_order); | auto graph_inputs = GetGraphInputs(execution_order); | ||||
| @@ -56,15 +58,17 @@ void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph, | |||||
| auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(return_input)); | auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(return_input)); | ||||
| auto output_tensors = liteKernel->GetOutputs(); | auto output_tensors = liteKernel->GetOutputs(); | ||||
| for (auto output_tensor : output_tensors) { | for (auto output_tensor : output_tensors) { | ||||
| tensor::TensorPtr output_tensor_ptr(output_tensor); | |||||
| outputs->push_back(output_tensor_ptr); | |||||
| // tensor::TensorPtr output_tensor_ptr(output_tensor); | |||||
| outputs->push_back(output_tensor); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| bool LiteInferKernelRuntime::Run(session::KernelGraph *graph) { | |||||
| bool LiteInferKernelRuntime::Run(session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs, | |||||
| std::vector<tensor::Tensor *> *outputs) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| BindInputOutput(graph, inputs, *outputs); | |||||
| std::vector<kernel::LiteKernel *> kernels; | std::vector<kernel::LiteKernel *> kernels; | ||||
| auto nodes = graph->execution_order(); | auto nodes = graph->execution_order(); | ||||
| for (const auto &node : nodes) { | for (const auto &node : nodes) { | ||||
| @@ -76,8 +80,7 @@ bool LiteInferKernelRuntime::Run(session::KernelGraph *graph) { | |||||
| } | } | ||||
| kernel::LiteKernelUtil::TopologicalSortKernels(kernels); | kernel::LiteKernelUtil::TopologicalSortKernels(kernels); | ||||
| Executor executor; | Executor executor; | ||||
| auto ret = executor.Run(kernels); | |||||
| auto ret = executor.Run(inputs, *outputs, kernels); | |||||
| return 0 == ret; | return 0 == ret; | ||||
| } | } | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -23,35 +23,28 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "src/runtime/allocator.h" | #include "src/runtime/allocator.h" | ||||
| #include "src/executor.h" | #include "src/executor.h" | ||||
| #include "runtime/device/kernel_runtime.h" | |||||
| // #include "runtime/device/kernel_runtime.h" | |||||
| #include "runtime/device/device_address.h" | #include "runtime/device/device_address.h" | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| #include "backend/session/kernel_graph.h" | #include "backend/session/kernel_graph.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class LiteInferKernelRuntime : public device::KernelRuntime { | |||||
| class LiteInferKernelRuntime { | |||||
| public: | public: | ||||
| LiteInferKernelRuntime() = default; | LiteInferKernelRuntime() = default; | ||||
| ~LiteInferKernelRuntime() override = default; | |||||
| ~LiteInferKernelRuntime() = default; | |||||
| bool Init() override { return true; } | |||||
| void BindInputOutput(const session::KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs, | |||||
| VectorRef *outputs); | |||||
| bool Run(session::KernelGraph *graph); | |||||
| bool Run(session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs, | |||||
| std::vector<tensor::Tensor *> *outputs); | |||||
| void AssignKernelAddress(session::KernelGraph *graph) {} | void AssignKernelAddress(session::KernelGraph *graph) {} | ||||
| protected: | protected: | ||||
| void BindInputOutput(const session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs, | |||||
| std::vector<tensor::Tensor *> *outputs); | |||||
| std::vector<CNodePtr> GetGraphInputs(const std::vector<CNodePtr> &execution_order); | std::vector<CNodePtr> GetGraphInputs(const std::vector<CNodePtr> &execution_order); | ||||
| bool SyncStream() override { return true; }; | |||||
| device::DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||||
| TypeId type_id) override { | |||||
| return nullptr; | |||||
| }; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| #endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ | #endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ | ||||
| @@ -16,11 +16,34 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "src/train/model_impl.h" | #include "src/train/model_impl.h" | ||||
| #include "schema/model_generated.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "schema/model_generated.h" | |||||
| #include "src/common/anf_importer/import_from_meta_graph.h" | |||||
| namespace mindspore::lite::train { | namespace mindspore::lite::train { | ||||
| std::shared_ptr<ModelImpl> ModelImpl::Import(const char *model_buf, size_t size) { | |||||
| MS_EXCEPTION_IF_NULL(model_buf); | |||||
| flatbuffers::Verifier verify((const uint8_t *)model_buf, size); | |||||
| if (!schema::VerifyMetaGraphBuffer(verify)) { | |||||
| MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; | |||||
| return nullptr; | |||||
| } | |||||
| // todo hangangqiang remove when copy primitive done | |||||
| auto *inner_buf = new char[size]; | |||||
| memcpy(inner_buf, model_buf, size); | |||||
| auto meta_graph = schema::GetMetaGraph(inner_buf); | |||||
| auto func_graph_model = std::make_shared<ModelImpl>(meta_graph); | |||||
| auto ret = func_graph_model->BuildOps(); | |||||
| if (0 != ret) { | |||||
| MS_LOG(ERROR) << "BuildOps failed"; | |||||
| return nullptr; | |||||
| } | |||||
| AnfImporterFromMetaGraph anfImporter(func_graph_model); | |||||
| anfImporter.Import(); | |||||
| return func_graph_model; | |||||
| } | |||||
| const lite::Primitive *ModelImpl::GetOp(const std::string &name) const { | const lite::Primitive *ModelImpl::GetOp(const std::string &name) const { | ||||
| auto iter = ops.find(name); | auto iter = ops.find(name); | ||||
| if (iter == ops.end()) { | if (iter == ops.end()) { | ||||
| @@ -98,6 +121,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { | |||||
| return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(srcPrim)); | return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(srcPrim)); | ||||
| case schema::PrimitiveType_Nhwc2Nchw: | case schema::PrimitiveType_Nhwc2Nchw: | ||||
| return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(srcPrim)); | return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(srcPrim)); | ||||
| case schema::PrimitiveType_MatMul: | |||||
| return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim)); | |||||
| default: | default: | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -115,5 +140,6 @@ int ModelImpl::BuildOps() { | |||||
| auto srcPrim = cNode->primitive(); | auto srcPrim = cNode->primitive(); | ||||
| this->ops[name] = CopyPrimitive(srcPrim); | this->ops[name] = CopyPrimitive(srcPrim); | ||||
| } | } | ||||
| return 0; | |||||
| } | } | ||||
| } // namespace mindspore::lite::train | } // namespace mindspore::lite::train | ||||
| @@ -15,11 +15,12 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ | #ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ | ||||
| #define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H | |||||
| #define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ | |||||
| #include <string> | #include <string> | ||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/ops/ops.h" | #include "src/ops/ops.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| @@ -28,7 +29,7 @@ namespace mindspore::lite { | |||||
| namespace train { | namespace train { | ||||
| class ModelImpl : public FuncGraph { | class ModelImpl : public FuncGraph { | ||||
| public: | public: | ||||
| static std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size); | |||||
| static std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size); // { return NULL; }; | |||||
| ModelImpl() = default; | ModelImpl() = default; | ||||
| explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {} | explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {} | ||||
| ~ModelImpl() override = default; | ~ModelImpl() override = default; | ||||
| @@ -37,16 +38,27 @@ class ModelImpl : public FuncGraph { | |||||
| void FreeMetaGraph(); | void FreeMetaGraph(); | ||||
| int BuildOps(); | int BuildOps(); | ||||
| void AddCNodeInputOutput(std::string name, const std::vector<int> &input, const std::vector<int> &output) { | |||||
| std::vector<int> *tuple = new std::vector<int>[2]; | |||||
| tuple[0] = input; | |||||
| tuple[1] = output; | |||||
| connectivity_[name] = tuple; | |||||
| } | |||||
| std::vector<int> *GetCNodeInputOutputIndices(std::string name) { return connectivity_[name]; } | |||||
| void AddAnfNode(int id, AnfNodePtr anf_ptr) { tensors_[id] = anf_ptr; } | |||||
| AnfNodePtr GetAnfNode(int id) { return tensors_[id]; } | |||||
| protected: | protected: | ||||
| lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim); | lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim); | ||||
| protected: | protected: | ||||
| const schema::MetaGraph *meta_graph = nullptr; | const schema::MetaGraph *meta_graph = nullptr; | ||||
| std::map<int, AnfNodePtr> tensors_; | |||||
| std::map<std::string, std::vector<int> *> connectivity_; | |||||
| std::map<std::string, lite::Primitive *> ops; | std::map<std::string, lite::Primitive *> ops; | ||||
| }; | }; | ||||
| } // namespace train | } // namespace train | ||||
| using ModelImpl = mindspore::lite::train::ModelImpl; | using ModelImpl = mindspore::lite::train::ModelImpl; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| #endif // MINDSPORE_LITE_INCLUDE_MODEL_H | |||||
| #endif // MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ | |||||
| @@ -0,0 +1,253 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <algorithm> | |||||
| #include "src/train/train_anf_session.h" | |||||
| #include "include/context.h" | |||||
| #include "mindspore/ccsrc/runtime/device/kernel_info.h" | |||||
| #include "mindspore/lite/src/train/train_session.h" | |||||
| #include "mindspore/lite/src/kernel_factory.h" | |||||
| #include "mindspore/lite/src/param_value_lite.h" | |||||
| #include "common/utils.h" | |||||
| #include "mindspore/lite/src/ops/ops.h" | |||||
| #include "ir/anf.h" | |||||
| #include "mindspore/lite/src/ir/tensor.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "src/ir/primitive_value.h" | |||||
| #include "src/train/model_impl.h" | |||||
| namespace mindspore { | |||||
| namespace session { | |||||
| static std::vector<int> GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) { | |||||
| auto nodeAbstract = anfNodePtr->abstract(); | |||||
| if (nodeAbstract != nullptr) { | |||||
| auto shape = nodeAbstract->GetShapeTrack(); | |||||
| if (!shape->isa<abstract::Shape>()) { | |||||
| MS_LOG(EXCEPTION) << "Not a Shape"; | |||||
| return {}; | |||||
| } | |||||
| auto dims = dyn_cast<abstract::Shape>(shape)->shape(); | |||||
| return dims; | |||||
| } else { | |||||
| MS_LOG(WARNING) << "abstract is nullptr, return empty dims"; | |||||
| return {}; | |||||
| } | |||||
| } | |||||
| static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) { | |||||
| auto nodeAbstract = anfNodePtr->abstract(); | |||||
| if (nodeAbstract != nullptr) { | |||||
| return schema::Format_NHWC; // XXX TODO -- extract Format from AnfNode | |||||
| } else { | |||||
| MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC"; | |||||
| return schema::Format_NHWC; | |||||
| } | |||||
| } | |||||
| static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) { | |||||
| auto nodeAbstract = anfNodePtr->abstract(); | |||||
| if (nodeAbstract != nullptr) { | |||||
| return TypeId::kNumberTypeFloat32; // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id(); | |||||
| } else { | |||||
| MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown"; | |||||
| return TypeId::kTypeUnknown; | |||||
| } | |||||
| } | |||||
| void TrainANFSession::Init(lite::Context *context) { | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| this->context_ = std::make_shared<lite::Context>(context->thread_num_, context->allocator, context->device_ctx_); | |||||
| } | |||||
| lite::tensor::Tensor *TrainANFSession::GetTensorForAnfNode(const AnfNodePtr anf_node) { | |||||
| lite::tensor::Tensor *out_tensor = tensors_[anf_node]; | |||||
| if (out_tensor == NULL) { | |||||
| out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node), | |||||
| GetAnfNodeOutDims(anf_node)); //, schema::NodeType_Parameter); | |||||
| tensors_[anf_node] = out_tensor; | |||||
| } | |||||
| return out_tensor; | |||||
| } | |||||
| int TrainANFSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) { | |||||
| auto return_node = kernel_graph->get_return(); | |||||
| auto node_list = TopoSort(return_node); | |||||
| auto model_imp = std::dynamic_pointer_cast<lite::train::ModelImpl>(func_graph_); | |||||
| for (auto &node : node_list) { | |||||
| if (!node->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| KernelRelation kernel_relation; | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| kernel_relation.node_full_name = cnode->fullname_with_scope(); | |||||
| kernel_relation.cnode = cnode; | |||||
| std::vector<int> *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope()); | |||||
| if (cnode_io_indices == NULL) { | |||||
| MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope(); | |||||
| } else { | |||||
| for (int i = 0; i < cnode_io_indices[1].size(); i++) { | |||||
| AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]); | |||||
| kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node)); | |||||
| } | |||||
| } | |||||
| lite::tensor::Tensor *tensor_ptr = nullptr; | |||||
| for (size_t index = 1; index < cnode->inputs().size(); ++index) { | |||||
| if (cnode->input(index)->isa<CNode>()) { | |||||
| auto input_cnode = cnode->input(index)->cast<CNodePtr>(); | |||||
| auto input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()]; | |||||
| // todo not support multi-outputs kernel sudo as spilt | |||||
| tensor_ptr = input_kernel_relation.output_tensor.front(); | |||||
| } else if (cnode->input(index)->isa<Parameter>()) { | |||||
| auto input_parameter = cnode->input(index)->cast<ParameterPtr>(); | |||||
| auto para = input_parameter->default_param(); | |||||
| auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para); | |||||
| // auto dims = param_value->tensor_shape(); | |||||
| // tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims); // schema::NodeType_ValueNode); | |||||
| tensor_ptr = GetTensorForAnfNode(cnode->input(index)); | |||||
| if ((param_value != nullptr) && (param_value->tensor_size() != 0)) { | |||||
| tensor_ptr->SetData(param_value->tensor_addr()); | |||||
| } | |||||
| } else if (cnode->input(index)->isa<ValueNode>()) { | |||||
| auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>(); | |||||
| // tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), | |||||
| // GetAnfNodeOutDims(input_valuenode)); // schema::NodeType_Parameter); | |||||
| tensor_ptr = GetTensorForAnfNode(input_valuenode); | |||||
| // todo(yankai) | |||||
| } else { | |||||
| MS_ASSERT(false); | |||||
| } | |||||
| kernel_relation.input_tensor.push_back(tensor_ptr); | |||||
| } | |||||
| kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation; | |||||
| } | |||||
| return 0; | |||||
| } | |||||
// Monotonically increasing id shared by all sessions; keys new kernel graphs.
GraphId TrainANFSession::graph_sum_ = 0;
// Creates an empty KernelGraph, assigns it the next free graph id and
// registers it in graphs_ before returning it.
KernelGraphPtr TrainANFSession::NewKernelGraph() {
  auto graph = std::make_shared<KernelGraph>();
  graph->set_graph_id(graph_sum_);
  graphs_[graph_sum_++] = graph;
  return graph;
}
| std::shared_ptr<KernelGraph> TrainANFSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto graph = NewKernelGraph(); | |||||
| graph->set_return(func_graph->get_return()); | |||||
| auto node_list = TopoSort(func_graph->get_return()); | |||||
| std::vector<CNodePtr> cnode_order; | |||||
| for (const auto &node : node_list) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node->isa<CNode>()) { | |||||
| auto cn_node = node->cast<CNodePtr>(); | |||||
| cnode_order.push_back(cn_node); | |||||
| } | |||||
| } | |||||
| graph->set_execution_order(cnode_order); | |||||
| return graph; | |||||
| } | |||||
| GraphId TrainANFSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| auto graph = ConstructKernelGraph(func_graph); | |||||
| func_graph_ = func_graph; | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_LOG(INFO) << "Set kernel info"; | |||||
| SetKernelInfo(graph.get()); | |||||
| (void)BuildKernelInputAndOutputFromFuncGraph(graph); | |||||
| MS_LOG(INFO) << "Build kernel"; | |||||
| auto ret = BuildKernel(graph.get()); | |||||
| if (0 != ret) { | |||||
| MS_LOG(EXCEPTION) << "BuildKernel failed"; | |||||
| } | |||||
| // return the graph id to backend | |||||
| auto graph_id = graph->graph_id(); | |||||
| graphs_[graph_id] = graph; | |||||
| MS_LOG(INFO) << "Compile graph " << graph_id << " success"; | |||||
| return graph_id; | |||||
| } | |||||
| void TrainANFSession::RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| std::vector<lite::tensor::Tensor *> *outputs) { | |||||
| auto &kernel_graph = graphs_[graph_id]; | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| MS_LOG(INFO) << "Bind input output address"; | |||||
| // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run | |||||
| // auto execution_order = kernel_graph->execution_order(); | |||||
| // Todo : hangangqiang | |||||
| // Reorder(&execution_order); | |||||
| // kernel_graph->set_execution_order(execution_order); | |||||
| MS_LOG(INFO) << "Run graph start"; | |||||
| auto ret = runtime_.Run(kernel_graph.get(), (std::vector<lite::tensor::Tensor *> &)inputs, *outputs); | |||||
| if (!ret) { | |||||
| MS_LOG(EXCEPTION) << "Run graph failed"; | |||||
| } | |||||
| MS_LOG(INFO) << "Run graph end"; | |||||
| } | |||||
| void TrainANFSession::SetKernelInfo(const KernelGraph *kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto &kernel_nodes = kernel_graph->execution_order(); | |||||
| for (const auto &kernel_node : kernel_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||||
| kernel_node->set_kernel_info(kernel_info); | |||||
| } | |||||
| } | |||||
| int TrainANFSession::BuildKernel(const KernelGraph *kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) { | |||||
| std::string kernel_name = iter->first; | |||||
| KernelRelation anf_register = iter->second; | |||||
| MS_EXCEPTION_IF_NULL(anf_register.cnode); | |||||
| if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) { | |||||
| continue; | |||||
| } | |||||
| auto value_node_prim = anf_register.cnode->input(0); | |||||
| MS_EXCEPTION_IF_NULL(value_node_prim); | |||||
| auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto node_primitive = (lite::Primitive *)(prim->GetPrimitive()); | |||||
| MS_EXCEPTION_IF_NULL(node_primitive); | |||||
| auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor); | |||||
| if (0 != ret) { | |||||
| MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name; | |||||
| return ret; | |||||
| } | |||||
| kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()}; | |||||
| auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor, | |||||
| node_primitive, context_.get(), desc); | |||||
| if (nullptr == kernel) { | |||||
| MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name; | |||||
| return -1; | |||||
| } | |||||
| std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel); | |||||
| kernel_mod->set_name(anf_register.cnode->fullname_with_scope()); | |||||
| // kernel->train(); | |||||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_register.cnode->kernel_info()); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| kernel_info->set_kernel_mod(kernel_mod); // XXX TODO -- only derived class KernelInfo has this method | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| } // namespace session | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||||
| #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <unordered_map> | |||||
| #include "include/context.h" | |||||
| #include "backend/session/session_basic.h" | |||||
| #include "backend/session/kernel_graph.h" | |||||
| #include "mindspore/lite/src/train/lite_kernel_runtime.h" | |||||
| // #include "backend/session/session_factory.h" | |||||
| namespace mindspore { | |||||
| namespace lite::tensor { | |||||
| class Tensor; | |||||
| } | |||||
| namespace session { | |||||
// Per-CNode record linking a node to the lite tensors it consumes and
// produces; filled by BuildKernelInputAndOutputFromFuncGraph and read by
// BuildKernel. Tensor pointers are not owned here (they are cached in
// TrainANFSession::tensors_).
struct KernelRelation {
  std::string node_full_name;                         // cnode->fullname_with_scope()
  std::vector<lite::tensor::Tensor *> input_tensor;   // tensors read by the kernel
  std::vector<lite::tensor::Tensor *> output_tensor;  // tensors written by the kernel
  CNodePtr cnode;                                     // the node this relation describes
};
// Session that compiles a (training) FuncGraph into lite CPU kernels and
// executes the compiled graph through LiteInferKernelRuntime.
class TrainANFSession {
 public:
  explicit TrainANFSession(lite::Context *context) { Init(context); }
  ~TrainANFSession() = default;
  // Compiles |func_graph| and returns the id under which it can be run.
  GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);
  // Runs the compiled graph |graph_id| on |inputs|, writing into |outputs|.
  void RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
                std::vector<lite::tensor::Tensor *> *outputs);
  // void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override
  // {};
 protected:
  // Copies the caller's context into context_.
  void Init(lite::Context *context);
  std::shared_ptr<lite::Context> context_ = nullptr;
  // Compiled graphs, keyed by the id handed out in CompileGraph.
  std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
  // Source of unique graph ids, shared by all sessions.
  static GraphId graph_sum_;
  KernelGraphPtr NewKernelGraph();
 private:
  // GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
  // GraphId CompileGraph(const char *model_buf, size_t size);
  std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
  int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph);
  // Returns (creating and caching on first use) the tensor for |anf_node|.
  lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node);
  void SetKernelInfo(const KernelGraph *kernel_graph);
  int BuildKernel(const KernelGraph *kernel_graph);
  lite::LiteInferKernelRuntime runtime_;
  // Kernel relations keyed by cnode full name; built during CompileGraph.
  std::map<std::string, KernelRelation> kernel_relation_infos_;
  FuncGraphPtr func_graph_ = NULL;
  // Tensors created by GetTensorForAnfNode; no delete is visible in this
  // file -- NOTE(review): confirm ownership / possible leak.
  std::map<AnfNodePtr, lite::tensor::Tensor *> tensors_;
};
| } // namespace session | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||||
| @@ -15,6 +15,8 @@ | |||||
| */ | */ | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "include/context.h" | |||||
| #include "mindspore/ccsrc/runtime/device/kernel_info.h" | |||||
| #include "mindspore/lite/src/train/train_session.h" | #include "mindspore/lite/src/train/train_session.h" | ||||
| #include "mindspore/lite/src/kernel_factory.h" | #include "mindspore/lite/src/kernel_factory.h" | ||||
| #include "mindspore/lite/src/param_value_lite.h" | #include "mindspore/lite/src/param_value_lite.h" | ||||
| @@ -25,6 +27,7 @@ | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "src/ir/primitive_value.h" | #include "src/ir/primitive_value.h" | ||||
| #include "src/train/model_impl.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| @@ -57,16 +60,32 @@ static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) { | |||||
| static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) { | static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) { | ||||
| auto nodeAbstract = anfNodePtr->abstract(); | auto nodeAbstract = anfNodePtr->abstract(); | ||||
| if (nodeAbstract != nullptr) { | if (nodeAbstract != nullptr) { | ||||
| return nodeAbstract->GetTypeTrack()->type_id(); | |||||
| return TypeId::kNumberTypeFloat32; // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id(); | |||||
| } else { | } else { | ||||
| MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown"; | MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown"; | ||||
| return TypeId::kTypeUnknown; | return TypeId::kTypeUnknown; | ||||
| } | } | ||||
| } | } | ||||
| void TrainSession::Init(lite::Context *context) { | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| this->context_ = std::make_shared<lite::Context>(context->thread_num_, context->allocator, context->device_ctx_); | |||||
| } | |||||
| lite::tensor::Tensor *TrainSession::GetTensorForAnfNode(const AnfNodePtr anf_node) { | |||||
| lite::tensor::Tensor *out_tensor = tensors_[anf_node]; | |||||
| if (out_tensor == NULL) { | |||||
| out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node), | |||||
| GetAnfNodeOutDims(anf_node)); //, schema::NodeType_Parameter); | |||||
| tensors_[anf_node] = out_tensor; | |||||
| } | |||||
| return out_tensor; | |||||
| } | |||||
| int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) { | int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) { | ||||
| auto return_node = kernel_graph->get_return(); | auto return_node = kernel_graph->get_return(); | ||||
| auto node_list = TopoSort(return_node); | auto node_list = TopoSort(return_node); | ||||
| auto model_imp = std::dynamic_pointer_cast<lite::train::ModelImpl>(func_graph_); | |||||
| for (auto &node : node_list) { | for (auto &node : node_list) { | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| continue; | continue; | ||||
| @@ -75,11 +94,16 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| kernel_relation.node_full_name = cnode->fullname_with_scope(); | kernel_relation.node_full_name = cnode->fullname_with_scope(); | ||||
| kernel_relation.cnode = cnode; | kernel_relation.cnode = cnode; | ||||
| auto *out_tensor = | |||||
| new tensor::Tensor(GetAnfNodeOutTypeId(cnode), GetAnfNodeOutDims(cnode), GetAnfNodeFormat(cnode), | |||||
| schema::NodeType_Parameter); | |||||
| kernel_relation.output_tensor.push_back(out_tensor); | |||||
| tensor::Tensor *tensor_ptr = nullptr; | |||||
| std::vector<int> *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope()); | |||||
| if (cnode_io_indices == NULL) { | |||||
| MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope(); | |||||
| } else { | |||||
| for (int i = 0; i < cnode_io_indices[1].size(); i++) { | |||||
| AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]); | |||||
| kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node)); | |||||
| } | |||||
| } | |||||
| lite::tensor::Tensor *tensor_ptr = nullptr; | |||||
| for (size_t index = 1; index < cnode->inputs().size(); ++index) { | for (size_t index = 1; index < cnode->inputs().size(); ++index) { | ||||
| if (cnode->input(index)->isa<CNode>()) { | if (cnode->input(index)->isa<CNode>()) { | ||||
| auto input_cnode = cnode->input(index)->cast<CNodePtr>(); | auto input_cnode = cnode->input(index)->cast<CNodePtr>(); | ||||
| @@ -90,17 +114,17 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k | |||||
| auto input_parameter = cnode->input(index)->cast<ParameterPtr>(); | auto input_parameter = cnode->input(index)->cast<ParameterPtr>(); | ||||
| auto para = input_parameter->default_param(); | auto para = input_parameter->default_param(); | ||||
| auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para); | auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para); | ||||
| auto dims = param_value->tensor_shape(); | |||||
| tensor_ptr = new tensor::Tensor(param_value->tensor_type(), dims, schema::Format_NHWC, | |||||
| schema::NodeType_ValueNode); // XXX TODO -- extract Format from AnfNode | |||||
| if (param_value->tensor_size() != 0) { | |||||
| // auto dims = param_value->tensor_shape(); | |||||
| // tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims); // schema::NodeType_ValueNode); | |||||
| tensor_ptr = GetTensorForAnfNode(cnode->input(index)); | |||||
| if ((param_value != nullptr) && (param_value->tensor_size() != 0)) { | |||||
| tensor_ptr->SetData(param_value->tensor_addr()); | tensor_ptr->SetData(param_value->tensor_addr()); | ||||
| } | } | ||||
| } else if (cnode->input(index)->isa<ValueNode>()) { | } else if (cnode->input(index)->isa<ValueNode>()) { | ||||
| auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>(); | auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>(); | ||||
| tensor_ptr = new tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), GetAnfNodeOutDims(input_valuenode), | |||||
| schema::Format_NHWC, | |||||
| schema::NodeType_Parameter); // XXX TODO -- extract Format from AnfNode | |||||
| // tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), | |||||
| // GetAnfNodeOutDims(input_valuenode)); // schema::NodeType_Parameter); | |||||
| tensor_ptr = GetTensorForAnfNode(input_valuenode); | |||||
| // todo(yankai) | // todo(yankai) | ||||
| } else { | } else { | ||||
| MS_ASSERT(false); | MS_ASSERT(false); | ||||
| @@ -111,7 +135,7 @@ int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &k | |||||
| } | } | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| #if 0 | |||||
| GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | ||||
| auto graph_id = graph_sum_; | auto graph_id = graph_sum_; | ||||
| auto graph = SessionBasic::ConstructKernelGraph(lst, outputs); | auto graph = SessionBasic::ConstructKernelGraph(lst, outputs); | ||||
| @@ -124,6 +148,17 @@ GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrLi | |||||
| } | } | ||||
| GraphId TrainSession::CompileGraph(const char *model_buf, size_t size) { return 0; } | GraphId TrainSession::CompileGraph(const char *model_buf, size_t size) { return 0; } | ||||
| #else | |||||
| GraphId TrainSession::graph_sum_ = 0; | |||||
| KernelGraphPtr TrainSession::NewKernelGraph() { | |||||
| auto graph = std::make_shared<KernelGraph>(); | |||||
| graph->set_graph_id(graph_sum_); | |||||
| graphs_[graph_sum_++] = graph; | |||||
| return graph; | |||||
| } | |||||
| #endif | |||||
| std::shared_ptr<KernelGraph> TrainSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) { | std::shared_ptr<KernelGraph> TrainSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| @@ -141,14 +176,14 @@ std::shared_ptr<KernelGraph> TrainSession::ConstructKernelGraph(const FuncGraphP | |||||
| graph->set_execution_order(cnode_order); | graph->set_execution_order(cnode_order); | ||||
| return graph; | return graph; | ||||
| } | } | ||||
| GraphId TrainSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | GraphId TrainSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | ||||
| auto graph = ConstructKernelGraph(func_graph); | auto graph = ConstructKernelGraph(func_graph); | ||||
| func_graph_ = func_graph; | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_LOG(INFO) << "Set kernel info"; | MS_LOG(INFO) << "Set kernel info"; | ||||
| SetKernelInfo(graph.get()); | SetKernelInfo(graph.get()); | ||||
| (void) BuildKernelInputAndOutputFromFuncGraph(graph); | |||||
| (void)BuildKernelInputAndOutputFromFuncGraph(graph); | |||||
| MS_LOG(INFO) << "Build kernel"; | MS_LOG(INFO) << "Build kernel"; | ||||
| auto ret = BuildKernel(graph.get()); | auto ret = BuildKernel(graph.get()); | ||||
| if (0 != ret) { | if (0 != ret) { | ||||
| @@ -162,18 +197,18 @@ GraphId TrainSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| return graph_id; | return graph_id; | ||||
| } | } | ||||
| void TrainSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Tensor *> &inputs, | |||||
| std::vector<tensor::Tensor *> &outputs) { | |||||
| void TrainSession::RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| std::vector<lite::tensor::Tensor *> *outputs) { | |||||
| auto &kernel_graph = graphs_[graph_id]; | auto &kernel_graph = graphs_[graph_id]; | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| MS_LOG(INFO) << "Bind input output address"; | MS_LOG(INFO) << "Bind input output address"; | ||||
| runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); | |||||
| // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run | |||||
| // auto execution_order = kernel_graph->execution_order(); | // auto execution_order = kernel_graph->execution_order(); | ||||
| // Todo : hangangqiang | // Todo : hangangqiang | ||||
| // Reorder(&execution_order); | // Reorder(&execution_order); | ||||
| // kernel_graph->set_execution_order(execution_order); | // kernel_graph->set_execution_order(execution_order); | ||||
| MS_LOG(INFO) << "Run graph start"; | MS_LOG(INFO) << "Run graph start"; | ||||
| auto ret = runtime_.Run(kernel_graph.get(), (std::vector<tensor::Tensor *> &) inputs, outputs); | |||||
| auto ret = runtime_.Run(kernel_graph.get(), (std::vector<lite::tensor::Tensor *> &)inputs, outputs); | |||||
| if (!ret) { | if (!ret) { | ||||
| MS_LOG(EXCEPTION) << "Run graph failed"; | MS_LOG(EXCEPTION) << "Run graph failed"; | ||||
| } | } | ||||
| @@ -199,34 +234,34 @@ int TrainSession::BuildKernel(const KernelGraph *kernel_graph) { | |||||
| if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) { | if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| lite::Context context; | |||||
| context.deviceCtx.type = lite::DeviceType::DT_CPU; | |||||
| auto value_node_prim = anf_register.cnode->input(0); | auto value_node_prim = anf_register.cnode->input(0); | ||||
| MS_EXCEPTION_IF_NULL(value_node_prim); | MS_EXCEPTION_IF_NULL(value_node_prim); | ||||
| auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim); | auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim); | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| auto node_primitive = (lite::Primitive *) (prim->GetPrimitive()); | |||||
| auto node_primitive = (lite::Primitive *)(prim->GetPrimitive()); | |||||
| MS_EXCEPTION_IF_NULL(node_primitive); | MS_EXCEPTION_IF_NULL(node_primitive); | ||||
| auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor); | auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor); | ||||
| if (0 != ret) { | if (0 != ret) { | ||||
| MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name; | MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, node_primitive->Type()}; | |||||
| kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()}; | |||||
| auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor, | auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor, | ||||
| node_primitive, &context, desc); | |||||
| node_primitive, context_.get(), desc); | |||||
| if (nullptr == kernel) { | if (nullptr == kernel) { | ||||
| MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name; | MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name; | ||||
| return -1; | return -1; | ||||
| } | } | ||||
| kernel->train(); | |||||
| auto *kernel_info = anf_register.cnode->kernel_info(); | |||||
| std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel); | std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel); | ||||
| kernel_info->set_kernel_mod(kernel_mod); | |||||
| kernel_mod->set_name(anf_register.cnode->fullname_with_scope()); | |||||
| // kernel->train(); | |||||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_register.cnode->kernel_info()); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| kernel_info->set_kernel_mod(kernel_mod); // XXX TODO -- only derived class KernelInfo has this method | |||||
| } | } | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,47 +19,58 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <unordered_map> | |||||
| #include "include/context.h" | |||||
| #include "backend/session/session_basic.h" | #include "backend/session/session_basic.h" | ||||
| #include "backend/session/kernel_graph.h" | #include "backend/session/kernel_graph.h" | ||||
| #include "mindspore/lite/src/train/lite_kernel_runtime.h" | #include "mindspore/lite/src/train/lite_kernel_runtime.h" | ||||
| #include "backend/session/session_factory.h" | |||||
| // #include "backend/session/session_factory.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite::tensor { | namespace lite::tensor { | ||||
| class Tensor; | class Tensor; | ||||
| } | } | ||||
| namespace session { | namespace session { | ||||
| struct KernelRelation { | struct KernelRelation { | ||||
| std::string node_full_name; | |||||
| std::vector<tensor::Tensor *> input_tensor; | |||||
| std::vector<tensor::Tensor *> output_tensor; | |||||
| CNodePtr cnode; | |||||
| std::string node_full_name; | |||||
| std::vector<lite::tensor::Tensor *> input_tensor; | |||||
| std::vector<lite::tensor::Tensor *> output_tensor; | |||||
| CNodePtr cnode; | |||||
| }; | }; | ||||
| class TrainSession : public SessionBasic { | |||||
| class TrainSession { | |||||
| public: | public: | ||||
| TrainSession() : SessionBasic() {} | |||||
| ~TrainSession() override = default; | |||||
| void Init(uint32_t device_id) override { | |||||
| SessionBasic::Init(device_id); | |||||
| context_ = std::make_shared<Context>(kCPUDevice, device_id); | |||||
| } | |||||
| explicit TrainSession(lite::Context * context) { Init(context); } | |||||
| ~TrainSession() = default; | |||||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override; | |||||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph); | |||||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | |||||
| void RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| std::vector<lite::tensor::Tensor *> *outputs); | |||||
| // void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override | |||||
| // {}; | |||||
| protected: | |||||
| void Init(lite::Context *context); | |||||
| std::shared_ptr<lite::Context> context_ = nullptr; | |||||
| std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_; | |||||
| static GraphId graph_sum_; | |||||
| KernelGraphPtr NewKernelGraph(); | |||||
| private: | private: | ||||
| GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||||
| GraphId CompileGraph(const char *model_buf, size_t size); | |||||
| // GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||||
| // GraphId CompileGraph(const char *model_buf, size_t size); | |||||
| std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph); | std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph); | ||||
| int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph); | int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph); | ||||
| lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node); | |||||
| void SetKernelInfo(const KernelGraph *kernel_graph); | void SetKernelInfo(const KernelGraph *kernel_graph); | ||||
| int BuildKernel(const KernelGraph *kernel_graph); | int BuildKernel(const KernelGraph *kernel_graph); | ||||
| lite::LiteInferKernelRuntime runtime_; | lite::LiteInferKernelRuntime runtime_; | ||||
| std::map<std::string, KernelRelation> kernel_relation_infos_; | std::map<std::string, KernelRelation> kernel_relation_infos_; | ||||
| FuncGraphPtr func_graph_ = NULL; | |||||
| std::map<AnfNodePtr, lite::tensor::Tensor *> tensors_; | |||||
| }; | }; | ||||
| MS_REG_SESSION(kCPUDevice, TrainSession); | |||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | ||||
| @@ -87,6 +87,16 @@ file(GLOB KERNEL_OP_SRC | |||||
| ${LITE_DIR}/src/runtime/kernel/arm/nnacl/int8/*.cc | ${LITE_DIR}/src/runtime/kernel/arm/nnacl/int8/*.cc | ||||
| ${LITE_DIR}/src/runtime/kernel/arm/nnacl/quantization/*.cc | ${LITE_DIR}/src/runtime/kernel/arm/nnacl/quantization/*.cc | ||||
| ) | ) | ||||
| file(GLOB KERNEL_OP_TRAIN_SRC | |||||
| ${LITE_DIR}/src/runtime/kernel/arm/nnacl/fp32_grad/*.cc | |||||
| ${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc | |||||
| ) | |||||
| if (SUPPORT_TRAIN) | |||||
| list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC}) | |||||
| endif() | |||||
| if (PLATFORM_ARM64) | if (PLATFORM_ARM64) | ||||
| # assembly | # assembly | ||||
| file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/src/runtime/kernel/arm/nnacl/assembly/arm64/*.s | file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/src/runtime/kernel/arm/nnacl/assembly/arm64/*.s | ||||
| @@ -245,12 +255,13 @@ if (SUPPORT_TRAIN) | |||||
| # ${SRC_DIR}/device/kernel_info.cc | # ${SRC_DIR}/device/kernel_info.cc | ||||
| # ${SRC_DIR}/device/kernel_runtime.cc | # ${SRC_DIR}/device/kernel_runtime.cc | ||||
| # ${SRC_DIR}/device/lite/kernel_runtime_extends.cc | # ${SRC_DIR}/device/lite/kernel_runtime_extends.cc | ||||
| ${LITE_DIR}/src/common/anf_importer/anf_importer.cc | |||||
| ${LITE_DIR}/src/common/anf_importer/import_from_meta_graph.cc | |||||
| ${LITE_DIR}/src/ir/primitive_value.cc | |||||
| ${LITE_DIR}/src/train/lite_kernel_runtime.cc | |||||
| ${LITE_DIR}/src/train/train_session.cc | |||||
| ${LITE_DIR}/src/train/model_impl.cc | |||||
| # ${LITE_DIR}/src/common/anf_importer/anf_importer.cc | |||||
| # ${LITE_DIR}/src/common/anf_importer/import_from_meta_graph.cc | |||||
| # ${LITE_DIR}/src/ir/primitive_value.cc | |||||
| # ${LITE_DIR}/src/train/lite_kernel_runtime.cc | |||||
| # ${LITE_DIR}/src/train/train_session.cc | |||||
| # ${LITE_DIR}/src/train/model_impl.cc | |||||
| ${LITE_DIR}/src/lite_session.cc # temporary | |||||
| ) | ) | ||||
| else() | else() | ||||
| set(TEST_LITE_SRC | set(TEST_LITE_SRC | ||||
| @@ -265,6 +276,10 @@ file(GLOB_RECURSE TEST_CASE_KERNEL_SRC | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc | ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc | ||||
| ) | ) | ||||
| file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc | |||||
| ) | |||||
| set(TEST_SRC | set(TEST_SRC | ||||
| ${TEST_LITE_SRC} | ${TEST_LITE_SRC} | ||||
| ${TEST_MINDDATA_SRC} | ${TEST_MINDDATA_SRC} | ||||
| @@ -278,7 +293,9 @@ set(TEST_SRC | |||||
| if (SUPPORT_TRAIN) | if (SUPPORT_TRAIN) | ||||
| set(TEST_SRC | set(TEST_SRC | ||||
| ${TEST_SRC} | ${TEST_SRC} | ||||
| ${TEST_DIR}/ut/src/train_test.cc | |||||
| ${TEST_CASE_KERNEL_TRAIN_SRC} | |||||
| # ${TEST_DIR}/ut/src/train_test.cc | |||||
| ${TEST_DIR}/ut/src/infer_test.cc # temporary | |||||
| ) | ) | ||||
| else() | else() | ||||
| set(TEST_SRC | set(TEST_SRC | ||||
| @@ -13,7 +13,6 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -25,7 +24,7 @@ | |||||
| #include "mindspore/lite/src/kernel_registry.h" | #include "mindspore/lite/src/kernel_registry.h" | ||||
| #include "mindspore/lite/src/ir/tensor.h" | #include "mindspore/lite/src/ir/tensor.h" | ||||
| #include "mindspore/lite/src/lite_kernel.h" | #include "mindspore/lite/src/lite_kernel.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class TestActGradFp32 : public mindspore::Common { | class TestActGradFp32 : public mindspore::Common { | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include "src/common/file_utils.h" | #include "src/common/file_utils.h" | ||||
| #include "src/common/file_utils_ext.h" | #include "src/common/file_utils_ext.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.h" | #include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/reduce.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h" | |||||
| #include "mindspore/lite/src/kernel_registry.h" | #include "mindspore/lite/src/kernel_registry.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "common/common_test.h" | #include "common/common_test.h" | ||||
| #include "src/common/file_utils.h" | #include "src/common/file_utils.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h" | |||||
| #include "mindspore/lite/src/kernel_registry.h" | #include "mindspore/lite/src/kernel_registry.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -21,8 +21,8 @@ | |||||
| #include "common/common_test.h" | #include "common/common_test.h" | ||||
| #include "src/common/file_utils.h" | #include "src/common/file_utils.h" | ||||
| #include "src/common/file_utils_ext.h" | #include "src/common/file_utils_ext.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/nnacl/conv_parameter.h" | #include "mindspore/lite/src/runtime/kernel/arm/nnacl/conv_parameter.h" | ||||
| #include "mindspore/lite/src/kernel_registry.h" | #include "mindspore/lite/src/kernel_registry.h" | ||||
| @@ -22,8 +22,8 @@ | |||||
| #include "mindspore/lite/src/kernel_registry.h" | #include "mindspore/lite/src/kernel_registry.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/common/file_utils.h" | #include "src/common/file_utils.h" | ||||
| #include "src/runtime/kernel/arm/fp32/pooling_grad.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/pooling_grad.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/pooling_grad.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/pooling_grad.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class TestPoolingGradFp32 : public mindspore::Common { | class TestPoolingGradFp32 : public mindspore::Common { | ||||
| @@ -0,0 +1,92 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include "utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "src/common/file_utils.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h" | |||||
| #include "src/kernel_registry.h" | |||||
| namespace mindspore { | |||||
| class TestSoftmaxCrossEntropyFp32 : public mindspore::Common { | |||||
| public: | |||||
| TestSoftmaxCrossEntropyFp32() {} | |||||
| }; | |||||
| TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { | |||||
| // prepare stage | |||||
| SoftmaxCrossEntropyParameter *sce_param = new SoftmaxCrossEntropyParameter(); | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/operators/sce_fp32_1_y_6_4.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::vector<int> dim_y({6, 4}); | |||||
| lite::tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); | |||||
| y_tensor.SetData(input_data); | |||||
| std::string label_path = "./test_data/operators/sce_fp32_1_l_6.bin"; | |||||
| auto ll_labels = reinterpret_cast<int64 *>(mindspore::lite::ReadFile(label_path.c_str(), &input_size)); | |||||
| auto labels = new int[6]; | |||||
| for (int i = 0; i < 6; i++) labels[i] = static_cast<int>(ll_labels[i]); | |||||
| std::vector<int> dim_l({6}); | |||||
| lite::tensor::Tensor l_tensor(TypeId::kNumberTypeInt32, dim_l); | |||||
| l_tensor.SetData(labels); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {&y_tensor, &l_tensor}; | |||||
| auto loss = new float[1]; | |||||
| std::vector<int> dim_dw({1}); | |||||
| lite::tensor::Tensor loss_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||||
| loss_tensor.SetData(loss); | |||||
| auto grad = new float[24]; | |||||
| lite::tensor::Tensor grad_tensor(TypeId::kNumberTypeFloat32, dim_y); | |||||
| grad_tensor.SetData(grad); | |||||
| std::vector<lite::tensor::Tensor *> outputs = {&grad_tensor, &loss_tensor}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftmaxCrossEntropy}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(sce_param), NULL, desc, nullptr); | |||||
| kernel_obj->Run(); | |||||
| printf("==================total loss=================\n"); | |||||
| std::cout << loss[0] << " ," << std::endl; | |||||
| printf("==================Testing Grad===============\n"); | |||||
| std::string output_path = "./test_data/operators/sce_fp32_1_loss_1.bin"; | |||||
| lite::CompareOutput(loss, output_path); | |||||
| ((mindspore::kernel::SparseSoftmaxCrossEntropyWithLogitsCPUKernel *)kernel_obj)->train(); | |||||
| kernel_obj->Run(); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 12; i++) { | |||||
| std::cout << grad[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string grad_path = "./test_data/operators/sce_fp32_1_dy_6_4.bin"; | |||||
| lite::CompareOutput(grad, grad_path); | |||||
| delete sce_param; | |||||
| l_tensor.SetData(NULL); | |||||
| y_tensor.SetData(NULL); | |||||
| MS_LOG(INFO) << "SoftmaxCrossEntropyFp32 passed"; | |||||
| } | |||||
| } // namespace mindspore | |||||