| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_INCLUDE_MODEL_H | |||
| #define MINDSPORE_LITE_INCLUDE_MODEL_H | |||
| #ifndef MINDSPORE_LITE_INCLUDE_MODEL_H_ | |||
| #define MINDSPORE_LITE_INCLUDE_MODEL_H_ | |||
| #include <vector> | |||
| #include "include/lite_utils.h" | |||
| @@ -46,14 +46,14 @@ struct Model { | |||
| static Model *Import(const char *model_buf, size_t size); | |||
| /// \brief Free meta graph temporary buffer | |||
| void Free(); | |||
| virtual void Free(); | |||
| /// \brief Free all temporay buffer | |||
| void Destroy(); | |||
| /// \brief Model destruct, free all memory | |||
| ~Model(); | |||
| virtual ~Model(); | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_INCLUDE_MODEL_H | |||
| #endif // MINDSPORE_LITE_INCLUDE_MODEL_H_ | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_MODEL_H_ | |||
| #define MINDSPORE_LITE_INCLUDE_TRAIN_MODEL_H_ | |||
| #include <vector> | |||
| #include "include/model.h" | |||
| namespace mindspore::lite { | |||
| struct TrainModel : public lite::Model { | |||
| /// \brief Static method to create a TrainModel pointer. | |||
| /// | |||
| /// \param[in] model_buf Define the buffer read from a model file. | |||
| /// \param[in] size Define bytes number of model buffer. | |||
| /// | |||
| /// \return Pointer of MindSpore Lite TrainModel. | |||
| static TrainModel *Import(const char *model_buf, size_t size); | |||
| /// \brief Free meta graph temporary buffer | |||
| void Free() override; | |||
| /// \brief TrainModel destruct, free all memory | |||
| virtual ~TrainModel(); | |||
| /// \brief Export Model into buf. | |||
| /// | |||
| /// \param[in] buf Define the buffer to Export into. If nullptr, buf will be allocated | |||
| /// \param[in] len size of the buffer. | |||
| /// | |||
| /// \return Pointer to buffer with exported model | |||
| char* ExportBuf(char* buf, size_t* len) const; | |||
| size_t buf_size_; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_INCLUDE_TRAIN_MODEL_H_ | |||
| @@ -22,7 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| struct Model; | |||
| struct TrainModel; | |||
| } | |||
| namespace session { | |||
| @@ -35,24 +35,19 @@ class TrainSession : public lite::LiteSession { | |||
| const session::KernelCallBack &after = nullptr) override; | |||
| int CompileGraph(lite::Model *model) override; | |||
| virtual void ReplaceOps(); | |||
| virtual void* ExportToBuf(lite::Model *model, void* buf, size_t* len) const; | |||
| // todo: output tensors by tensor name | |||
| std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputMap() const; | |||
| std::vector<tensor::MSTensor *> GetOutputsByName(const std::string &node_name) const; | |||
| virtual void* ExportToBuf(char* buf, size_t* len) const; | |||
| virtual void train(); | |||
| bool is_train() { return train_mode_ == true; } | |||
| virtual void eval(); | |||
| bool is_eval() { return train_mode_ == false; } | |||
| virtual void Train(); | |||
| bool IsTrain() { return train_mode_ == true; } | |||
| virtual void Eval(); | |||
| bool IsEval() { return train_mode_ == false; } | |||
| protected: | |||
| virtual void ReplaceOps(); | |||
| bool train_mode_ = false; | |||
| lite::Model *model_ = nullptr; | |||
| std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> ext_output_map_; | |||
| // private: | |||
| lite::TrainModel *model_ = nullptr; | |||
| std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_map_; | |||
| std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_; | |||
| }; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -56,7 +56,7 @@ void FusedBatchNormFp32(const void *input, const void *scale, const void *offset | |||
| void FusedBatchNormFp32MeanVar(const float *input, float momentum, float *run_mean, float *run_var, | |||
| BatchNormParameter *param, float *save_mean, float *save_inv_var) { | |||
| float N = param->channel_ * param->unit_; | |||
| float N = (float)param->unit_; | |||
| for (int i = 0; i < param->unit_; i++) { | |||
| for (int f = 0; f < param->channel_; f++) { | |||
| int idx = i * param->channel_ + f; | |||
| @@ -64,11 +64,12 @@ void FusedBatchNormFp32MeanVar(const float *input, float momentum, float *run_me | |||
| run_var[f] += input[idx] * input[idx]; | |||
| } | |||
| } | |||
| const float VN = (N > 1.0f) ? (N - 1.0f) : 1.0f; | |||
| for (int f = 0; f < param->channel_; f++) { | |||
| run_mean[f] = run_mean[f] / N; | |||
| run_var[f] = run_var[f] / N - run_mean[f] * run_mean[f]; | |||
| run_var[f] = run_var[f] / VN - run_mean[f] * run_mean[f]; | |||
| save_mean[f] = momentum * save_mean[f] + (1 - momentum) * run_mean[f]; | |||
| float inv_var = 1.f / sqrt(run_var[f] + param->epsilon_); | |||
| const float inv_var = 1.f / sqrt(run_var[f] + param->epsilon_); | |||
| save_inv_var[f] = momentum * save_inv_var[f] + (1 - momentum) * inv_var; | |||
| } | |||
| } | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_ACTIVATION_GRAD_H_ | |||
| #define MINDSPORE_LITE_NNACL_ACTIVATION_GRAD_H_ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| @@ -42,4 +42,4 @@ int HSigmoidGrad(float *src0, float *src1, int length, float *dst); | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_ACTIVATION_GRAD_H_ | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||
| @@ -25,4 +25,4 @@ void ElementMulAndDivNegSquare(const float *a, const float *b, const float *deno | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_ARITHMETIC_GRAD_H_ | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_ | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP32_BATCH_NORM_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP32_BATCH_NORM_H_ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_BATCH_NORM_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP32_GRAD_BATCH_NORM_H_ | |||
| #include "nnacl/op_base.h" | |||
| @@ -39,4 +39,4 @@ void backwardScale(const float *x, const float *mean, const float *invar, const | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_BATCH_NORM_H_ | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_BATCH_NORM_H_ | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_PACK_EXT_H_ | |||
| #define MINDSPORE_LITE_NNACL_PACK_EXT_H_ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_PACK_EXT_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP32_GRAD_PACK_EXT_H_ | |||
| #include "nnacl/conv_parameter.h" | |||
| @@ -29,4 +29,4 @@ void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_PACK_EXT_H | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_PACK_EXT_H_ | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP32_REDUCE_GRAD_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP32_REDUCE_GRAD_H_ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_REDUCE_GRAD_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP32_GRAD_REDUCE_GRAD_H_ | |||
| #include <stddef.h> | |||
| @@ -27,4 +27,4 @@ void ReduceSumByAxes(const float *input, const int *input_dims, float *output, c | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_REDUCE_GRAD_H_ | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_REDUCE_GRAD_H_ | |||
| @@ -54,6 +54,7 @@ if (SUPPORT_TRAIN) | |||
| ${ANF_SRC} | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_model.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc | |||
| ) | |||
| endif () | |||
| @@ -41,9 +41,12 @@ static float CompareOutputRelativeData(float *output_data, float *correct_data, | |||
| int CompareRelativeOutput(float *output_data, std::string file_path) { | |||
| size_t output_size; | |||
| auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); | |||
| if (ground_truth == nullptr) { | |||
| return 1; | |||
| } | |||
| size_t output_num = output_size / sizeof(float); | |||
| int error = CompareOutputRelativeData(output_data, ground_truth, output_num); | |||
| delete [] ground_truth; | |||
| delete[] ground_truth; | |||
| if (error > 1e-4) { | |||
| return 1; | |||
| } | |||
| @@ -55,9 +58,8 @@ float RelativeOutputError(float *output_data, std::string file_path) { | |||
| auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); | |||
| size_t output_num = output_size / sizeof(float); | |||
| float error = CompareOutputRelativeData(output_data, ground_truth, output_num); | |||
| delete [] ground_truth; | |||
| delete[] ground_truth; | |||
| return error; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -278,16 +278,13 @@ int LiteSession::CompileGraph(Model *model) { | |||
| is_running_.store(false); | |||
| return ret; | |||
| } | |||
| ret = executor->Prepare(this->kernels_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare kernels failed: " << ret; | |||
| is_running_.store(false); | |||
| return ret; | |||
| } | |||
| #ifndef SUPPORT_TRAIN | |||
| model->Free(); | |||
| #endif | |||
| is_running_.store(false); | |||
| return RET_OK; | |||
| } | |||
| @@ -21,7 +21,7 @@ | |||
| #include "include/version.h" | |||
| namespace mindspore::lite { | |||
| namespace { | |||
| bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) { | |||
| for (size_t i = 0; i < meta_graph->nodes()->size(); ++i) { | |||
| Model::Node *node = new (std::nothrow) Model::Node(); | |||
| @@ -66,7 +66,6 @@ bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model) { | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace | |||
| Model *Model::Import(const char *model_buf, size_t size) { | |||
| if (model_buf == nullptr) { | |||
| @@ -39,11 +39,11 @@ int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodeP | |||
| return RET_ERROR; | |||
| } | |||
| auto attr = std::make_unique<schema::ActivationGradT>(); | |||
| if (prim.name() == "ReLU") { | |||
| if (prim.name() == "ReluGrad") { | |||
| attr->type = schema::ActivationType_RELU; | |||
| } else if (prim.name() == "Sigmoid") { | |||
| } else if (prim.name() == "SigmoidGrad") { | |||
| attr->type = schema::ActivationType_SIGMOID; | |||
| } else if (prim.name() == "ReLU6") { | |||
| } else if (prim.name() == "Relu6Grad") { | |||
| attr->type = schema::ActivationType_RELU6; | |||
| } | |||
| // auto alpha = GetValue<float>(prim.GetAttr("alpha")); | |||
| @@ -64,7 +64,7 @@ int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flat | |||
| MS_LOG(ERROR) << "value_as_ActivationGrad return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateActivationGrad(*fbb, attr->type()); | |||
| auto val_offset = schema::CreateActivationGrad(*fbb, attr->type(), attr->alpha()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ActivationGrad, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| @@ -23,6 +23,29 @@ int AddN::GetN() const { return this->primitive_->value.AsAddN()->N; } | |||
| void AddN::SetN(int n) { this->primitive_->value.AsAddN()->N = n; } | |||
| int AddN::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_AddN; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_AddN) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::AddNT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int AddN::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_ADDN_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_ADDN_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| @@ -32,6 +32,7 @@ class AddN : public PrimitiveC { | |||
| AddN() = default; | |||
| explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetN(int n); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| AddN() = default; | |||
| @@ -43,4 +44,4 @@ class AddN : public PrimitiveC { | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_ | |||
| #endif // MINDSPORE_LITE_SRC_OPS_ADDN_H_ | |||
| @@ -43,8 +43,11 @@ int BNGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->eps = GetValue<float>(prim.GetAttr("eps")); | |||
| attr->momentum = GetValue<float>(prim.GetAttr("momentum")); | |||
| // FusedBatchNormGrad dows not get this attribute | |||
| if (prim.GetAttr("eps") != nullptr) { | |||
| attr->eps = GetValue<float>(prim.GetAttr("eps")); | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| @@ -27,6 +27,32 @@ void FusedBatchNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsFused | |||
| void FusedBatchNorm::SetMomentum(float momentum) { this->primitive_->value.AsFusedBatchNorm()->momentum = momentum; } | |||
| void FusedBatchNorm::SetSpatial(int spatial) { this->primitive_->value.AsFusedBatchNorm()->spatial = spatial; } | |||
| int FusedBatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_FusedBatchNorm; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_FusedBatchNorm) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::FusedBatchNormT(); | |||
| attr->epsilon = GetValue<float>(prim.GetAttr("epsilon")); | |||
| attr->momentum = GetValue<float>(prim.GetAttr("momentum")); | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int FusedBatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_FUSED_BATCH_NORM_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_FUSED_BATCH_NORM_H_ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_FUSED_BATCHNORM_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_FUSED_BATCHNORM_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| @@ -34,6 +34,7 @@ class FusedBatchNorm : public PrimitiveC { | |||
| void SetEpsilon(float epsilon); | |||
| void SetMomentum(float momentum); | |||
| void SetSpatial(int spatial); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| FusedBatchNorm() = default; | |||
| @@ -46,4 +47,4 @@ class FusedBatchNorm : public PrimitiveC { | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_FUSED_BATCH_NORM_H_ | |||
| #endif // MINDSPORE_LITE_SRC_OPS_FUSED_BATCHNORM_H_ | |||
| @@ -124,6 +124,7 @@ | |||
| #include "src/ops/sparse_to_dense.h" | |||
| #include "src/ops/detection_post_process.h" | |||
| #include "src/ops/dropout.h" | |||
| #include "src/ops/real_div.h" | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #endif | |||
| @@ -355,6 +356,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| const auto &op_type = prim.name(); | |||
| if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid") { | |||
| return NewPrimitiveC<Activation>(prim, inputs, quantType); | |||
| } else if (op_type == "AddN") { | |||
| return NewPrimitiveC<AddN>(prim, inputs, quantType); | |||
| } else if (op_type == "BatchNorm") { | |||
| return NewPrimitiveC<BatchNorm>(prim, inputs, quantType); | |||
| } else if (op_type == "BiasAdd") { | |||
| @@ -369,6 +372,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Dequant>(prim, inputs, quantType); | |||
| } else if (op_type == "Flatten") { | |||
| return NewPrimitiveC<Flatten>(prim, inputs, quantType); | |||
| } else if (op_type == "FusedBatchNorm") { | |||
| return NewPrimitiveC<FusedBatchNorm>(prim, inputs, quantType); | |||
| } else if (op_type == "make_tuple") { | |||
| return NewPrimitiveC<MakeTuple>(prim, inputs, quantType); | |||
| } else if (op_type == "MatMul") { | |||
| @@ -379,8 +384,20 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Pooling>(prim, inputs, quantType); | |||
| } else if (op_type == "Quant") { | |||
| return NewPrimitiveC<Quant>(prim, inputs, quantType); | |||
| } else if (op_type == "RealDiv") { | |||
| return NewPrimitiveC<RealDiv>(prim, inputs, quantType); | |||
| } else if (op_type == "ReduceMax") { | |||
| return NewPrimitiveC<Reduce>(prim, inputs, quantType); | |||
| } else if (op_type == "ReduceMean") { | |||
| return NewPrimitiveC<Reduce>(prim, inputs, quantType); | |||
| } else if (op_type == "ReduceMin") { | |||
| return NewPrimitiveC<Reduce>(prim, inputs, quantType); | |||
| } else if (op_type == "ReduceProd") { | |||
| return NewPrimitiveC<Reduce>(prim, inputs, quantType); | |||
| } else if (op_type == "ReduceSum") { | |||
| return NewPrimitiveC<Reduce>(prim, inputs, quantType); | |||
| } else if (op_type == "ReduceSumSquare") { | |||
| return NewPrimitiveC<Reduce>(prim, inputs, quantType); | |||
| } else if (op_type == "Reshape") { | |||
| return NewPrimitiveC<Reshape>(prim, inputs, quantType); | |||
| } else if (op_type == "TensorAdd") { | |||
| @@ -402,7 +419,6 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| } else if (op_type == "Cast") { | |||
| return NewPrimitiveC<Cast>(prim, inputs, quantType); | |||
| #ifdef SUPPORT_TRAIN | |||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | |||
| return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType); | |||
| @@ -424,14 +440,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "FlattenGrad") { | |||
| return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType); | |||
| #endif | |||
| #ifdef SUPPORT_TRAIN0 | |||
| } else if (op_type == "PowerGrad") { | |||
| return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "NegGrad") { | |||
| return NewPrimitiveC<NegGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "LogGrad") { | |||
| return NewPrimitiveC<LogGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "FusedBatchNormGrad") { | |||
| return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "Tile") { | |||
| return NewPrimitiveC<Tile>(prim, inputs, quantType); | |||
| #endif | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported primitive type in Create : " << op_type; | |||
| @@ -929,9 +941,9 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { | |||
| case schema::PrimitiveType_MulGrad: | |||
| return NewPrimitiveC<ArithmeticGrad>(primitive); | |||
| case schema::PrimitiveType_DivGrad: | |||
| return NewPrimitiveC<ArithmeticGrad>(primitive); | |||
| return NewPrimitiveC<ArithmeticGrad>(primitive); | |||
| case schema::PrimitiveType_SoftmaxCrossEntropy: | |||
| return NewPrimitiveC<SoftmaxCrossEntropy>(primitive); | |||
| return NewPrimitiveC<SoftmaxCrossEntropy>(primitive); | |||
| case schema::PrimitiveType_NegGrad: | |||
| return NewPrimitiveC<NegGrad>(primitive); | |||
| case schema::PrimitiveType_LogGrad: | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/real_div.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int RealDiv::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_RealDiv; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_RealDiv) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::RealDivT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| #endif | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_REAL_DIV_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_REAL_DIV_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "ir/dtype/type_id.h" | |||
| #include "src/ops/arithmetic.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class RealDiv : public Arithmetic { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(RealDiv, Arithmetic); | |||
| RealDiv() = default; | |||
| explicit RealDiv(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| RealDiv() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override { | |||
| return RET_ERROR; | |||
| } | |||
| #endif | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_OPS_REAL_DIV_H_ | |||
| @@ -49,7 +49,19 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->mode = schema::ReduceMode_ReduceMean; | |||
| if (prim.name() == "ReduceMean") { | |||
| attr->mode = schema::ReduceMode_ReduceMean; | |||
| } else if (prim.name() == "ReduceSum") { | |||
| attr->mode = schema::ReduceMode_ReduceSum; | |||
| } else if (prim.name() == "ReduceMax") { | |||
| attr->mode = schema::ReduceMode_ReduceMax; | |||
| } else if (prim.name() == "ReduceMin") { | |||
| attr->mode = schema::ReduceMode_ReduceMin; | |||
| } else if (prim.name() == "ReduceProd") { | |||
| attr->mode = schema::ReduceMode_ReduceProd; | |||
| } else if (prim.name() == "ReduceSumSquare") { | |||
| attr->mode = schema::ReduceMode_ReduceSumSquare; | |||
| } | |||
| attr->keepDims = GetValue<bool>(prim.GetAttr("keep_dims")); | |||
| if (inputs.size() == kAnfPopulaterTwo) { | |||
| @@ -94,6 +94,7 @@ int SoftmaxCrossEntropy::InferShape(std::vector<Tensor *> inputs, std::vector<Te | |||
| outshape.push_back(1); | |||
| out->set_shape(outshape); | |||
| out->set_data_type(in0->data_type()); | |||
| out->SetFormat(in0->GetFormat()); | |||
| if (1 < outputs.size()) { | |||
| auto *grads = outputs.at(1); | |||
| @@ -28,6 +28,29 @@ std::vector<int> Tile::GetDims() const { return this->primitive_->value.AsTile() | |||
| void Tile::SetDims(const std::vector<int> &dims) { this->primitive_->value.AsTile()->dims = dims; } | |||
| int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Tile; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Tile) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::TileT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| std::vector<int> Tile::GetMultiples() const { | |||
| @@ -34,6 +34,7 @@ class Tile : public PrimitiveC { | |||
| explicit Tile(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetMultiples(const std::vector<int> &multiples); | |||
| void SetDims(const std::vector<int> &dims); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Tile() = default; | |||
| @@ -45,17 +45,8 @@ int ApplyMomentumCPUKernel::Run() { | |||
| float moment = reinterpret_cast<float *>(in_tensors_[4]->MutableData())[0]; | |||
| size_t elem_num = in_tensors_[0]->ElementsNum(); | |||
| // align format | |||
| if (in_tensors_[3]->shape().size() == 4 && in_tensors_[3]->GetFormat() == schema::Format_NCHW && | |||
| in_tensors_[0]->GetFormat() == schema::Format_KHWC) { | |||
| PackNCHWToNHWCFp32(gradient, workspace, in_tensors_[0]->Batch(), in_tensors_[0]->Height() * in_tensors_[0]->Width(), | |||
| in_tensors_[0]->Channel()); | |||
| } else { | |||
| memcpy(workspace, gradient, in_tensors_[3]->ElementsNum() * sizeof(float)); | |||
| } | |||
| for (size_t i = 0; i < elem_num; ++i) { | |||
| accumulate[i] = accumulate[i] * moment + workspace[i]; // * (1.0 - moment); | |||
| accumulate[i] = accumulate[i] * moment + gradient[i]; // * (1.0 - moment); | |||
| weight[i] -= accumulate[i] * learning_rate; | |||
| } | |||
| return RET_OK; | |||
| @@ -67,12 +58,7 @@ int ApplyMomentumCPUKernel::Init() { | |||
| auto accumulate = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||
| for (size_t i = 0; i < elem_num; i++) accumulate[i] = 0.0; | |||
| workspace = new float[elem_num]; | |||
| if (workspace == nullptr) { | |||
| MS_LOG(ERROR) << "apply momentum workspace fail to malloc!"; | |||
| return RET_ERROR; | |||
| } | |||
| return 0; | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_APPLY_MOMENTUM_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_APPLY_MOMENTUM_H_ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_APPLY_MOMENTUM_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_APPLY_MOMENTUM_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| @@ -27,20 +27,12 @@ class ApplyMomentumCPUKernel : public LiteKernel { | |||
| explicit ApplyMomentumCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), workspace(nullptr) {} | |||
| ~ApplyMomentumCPUKernel() override { | |||
| if (workspace) | |||
| delete[] workspace; | |||
| } | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~ApplyMomentumCPUKernel() override {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| float *workspace; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_APPLY_MOMENTUM_H_ | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_APPLY_MOMENTUM_H_ | |||
| @@ -42,4 +42,4 @@ class ConvolutionGradInputCPUKernel : public LiteKernel { | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_ | |||
| @@ -1,68 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include "src/runtime/kernel/arm/fp32_grad/depend.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Depend; | |||
| namespace mindspore::kernel { | |||
| int DependCPUKernel::Init() { return RET_OK; } | |||
| int DependCPUKernel::ReSize() { return 0; } | |||
| int DependCPUKernel::Run() { | |||
| // auto ret = Prepare(); | |||
| // if (ret != RET_OK) { | |||
| // MS_LOG(ERROR) << "Prepare failed."; | |||
| // return RET_ERROR; | |||
| // } | |||
| // auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||
| // auto out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| // | |||
| // memcpy(out, in, in_tensors_.at(0)->Size()); | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuDependFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Depend); | |||
| auto *kernel = new (std::nothrow) DependCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| MS_ASSERT(kernel != nullptr); | |||
| auto ret = kernel->Init(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Depend, CpuDependFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -1,46 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DEPEND_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DEPEND_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "ir/anf.h" | |||
| #include "nnacl/fp32/arithmetic.h" | |||
| namespace mindspore::kernel { | |||
| class DependCPUKernel : public LiteKernel { | |||
| public: | |||
| explicit DependCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| param = parameter; | |||
| } | |||
| ~DependCPUKernel() override = default; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| OpParameter *param; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DEPEND_H_ | |||
| @@ -0,0 +1,111 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/primitive_c.h" | |||
| #include "include/train_model.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/common/graph_util.h" | |||
| namespace mindspore::lite { | |||
| bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model); | |||
| bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model); | |||
| TrainModel *TrainModel::Import(const char *model_buf, size_t size) { | |||
| if (model_buf == nullptr) { | |||
| MS_LOG(ERROR) << "The model buf is nullptr"; | |||
| return nullptr; | |||
| } | |||
| flatbuffers::Verifier verify((const uint8_t *)model_buf, size); | |||
| if (!schema::VerifyMetaGraphBuffer(verify)) { | |||
| MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; | |||
| return nullptr; | |||
| } | |||
| TrainModel *model = new (std::nothrow) TrainModel(); | |||
| if (model == nullptr) { | |||
| MS_LOG(ERROR) << "new model fail!"; | |||
| return nullptr; | |||
| } | |||
| model->buf = reinterpret_cast<char *>(malloc(size)); | |||
| if (model->buf == nullptr) { | |||
| MS_LOG(ERROR) << "new inner model buf fail!"; | |||
| return nullptr; | |||
| } | |||
| memcpy(model->buf, model_buf, size); | |||
| model->buf_size_ = size; | |||
| auto meta_graph = schema::GetMetaGraph(model->buf); | |||
| if (meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "meta_graph is nullptr!"; | |||
| return nullptr; | |||
| } | |||
| if (meta_graph->name() != nullptr) { | |||
| model->name_ = meta_graph->name()->c_str(); | |||
| } | |||
| if (meta_graph->version() != nullptr) { | |||
| model->version_ = meta_graph->version()->c_str(); | |||
| } | |||
| auto in_count = meta_graph->inputIndex()->size(); | |||
| for (uint32_t i = 0; i < in_count; ++i) { | |||
| model->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i))); | |||
| } | |||
| auto out_count = meta_graph->outputIndex()->size(); | |||
| for (uint32_t i = 0; i < out_count; ++i) { | |||
| model->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i))); | |||
| } | |||
| if (!ConvertNodes(meta_graph, model)) { | |||
| delete model; | |||
| return nullptr; | |||
| } | |||
| if (!ConvertTensors(meta_graph, model)) { | |||
| delete model; | |||
| return nullptr; | |||
| } | |||
| return model; | |||
| } | |||
| void TrainModel::Free() { | |||
| } | |||
| char* TrainModel::ExportBuf(char* buffer, size_t* len) const { | |||
| MS_EXCEPTION_IF_NULL(len); | |||
| if (buf_size_ == 0 || buf == nullptr) { | |||
| MS_LOG(ERROR) << "Model::Export is only available for Train Session"; | |||
| return nullptr; | |||
| } | |||
| if (*len < buf_size_ && buffer != nullptr) { | |||
| MS_LOG(ERROR) << "Buffer is too small, Export Failed"; | |||
| return nullptr; | |||
| } | |||
| if (buffer == nullptr) | |||
| buffer = reinterpret_cast<char *>(malloc(buf_size_)); | |||
| if (buffer == nullptr) { | |||
| MS_LOG(ERROR) << "allocated model buf fail!"; | |||
| return nullptr; | |||
| } | |||
| memcpy(buffer, buf, buf_size_); | |||
| *len = buf_size_; | |||
| return buffer; | |||
| } | |||
| TrainModel::~TrainModel() { | |||
| Model::Free(); | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -18,8 +18,10 @@ | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| #include "include/context.h" | |||
| #include "include/train_model.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/common/utils.h" | |||
| #include "mindspore/lite/src/tensor.h" | |||
| #include "src/tensor.h" | |||
| #include "src/train/loss_kernel.h" | |||
| #include "src/train/train_populate_parameter.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| @@ -29,6 +31,15 @@ | |||
| namespace mindspore::session { | |||
| static size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter) { | |||
| for (size_t i = 0; i < where.size(); i++) { | |||
| if (where[i] == searchParameter) { | |||
| return i; | |||
| } | |||
| } | |||
| return where.size(); | |||
| } | |||
| TrainSession::TrainSession() { kernel::PopulateTrainParameters(); } | |||
| void TrainSession::ReplaceOps() { | |||
| @@ -37,118 +48,99 @@ void TrainSession::ReplaceOps() { | |||
| mindspore::kernel::CpuConvTrainFp32KernelCreator); | |||
| mindspore::lite::KernelRegistrar tmp0(mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, | |||
| mindspore::schema::PrimitiveType_DepthwiseConv2D, | |||
| mindspore::kernel::CpuConvTrainFp32KernelCreator); | |||
| mindspore::schema::PrimitiveType_DepthwiseConv2D, | |||
| mindspore::kernel::CpuConvTrainFp32KernelCreator); | |||
| } | |||
| int TrainSession::CompileGraph(lite::Model *model) { | |||
| model_ = model; | |||
| model_ = dynamic_cast<lite::TrainModel *>(model); | |||
| if (model_ == nullptr) { | |||
| MS_LOG(ERROR) << "TrainSession can only compile TrainModels"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| ReplaceOps(); | |||
| return LiteSession::CompileGraph(model); | |||
| auto ret = LiteSession::CompileGraph(model); | |||
| orig_output_map_ = output_node_map_; | |||
| orig_output_tensor_map_ = output_tensor_map_; | |||
| return ret; | |||
| } | |||
| TrainSession::~TrainSession() { | |||
| for (auto it1 = ext_output_map_.begin(); it1 != ext_output_map_.end(); ++it1) { | |||
| if ((output_node_map_.find(it1->first) == output_node_map_.end()) || train_mode_) { | |||
| // Delete if not from output_node_map_ | |||
| auto tensor_ptr = it1->second.back(); | |||
| delete tensor_ptr; | |||
| it1->second.pop_back(); | |||
| } | |||
| } | |||
| } | |||
| TrainSession::~TrainSession() { delete model_; } | |||
| void *TrainSession::ExportToBuf(lite::Model *model, void *buf, size_t *len) const { | |||
| return nullptr; | |||
| } | |||
| void *TrainSession::ExportToBuf(char *buf, size_t *len) const { return model_->ExportBuf(buf, len); } | |||
| int TrainSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) { | |||
| auto ms_output_tensors = GetOutputMap(); | |||
| this->outputs_.clear(); | |||
| for (auto ms_tensors : ms_output_tensors) | |||
| for (auto ms_tensors : output_node_map_) | |||
| for (auto ms_tensor : ms_tensors.second) this->outputs_.push_back((dynamic_cast<lite::Tensor *>(ms_tensor))); | |||
| if (train_mode_) return LiteSession::RunGraph(before, after); | |||
| // object is expected to run only inference part of graph | |||
| // prepare a list of kernels till the loss function -- temporary solution | |||
| std::vector<kernel::LiteKernel *> infference_kernels; | |||
| std::vector<kernel::LiteKernel *> inference_kernels; | |||
| for (auto kernel : this->kernels_) { | |||
| if (dynamic_cast<const kernel::LossKernel *>(kernel) != nullptr) break; | |||
| infference_kernels.push_back(kernel); | |||
| inference_kernels.push_back(kernel); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(this->context_); | |||
| lite::Executor executor; | |||
| if (before == nullptr && after == nullptr) { | |||
| return executor.Run(this->inputs_, this->outputs_, infference_kernels, this->context_->allocator.get()); | |||
| return executor.Run(this->inputs_, this->outputs_, inference_kernels, this->context_->allocator.get()); | |||
| } else { | |||
| return executor.Run(this->inputs_, this->outputs_, infference_kernels, this->context_->allocator.get(), before, | |||
| return executor.Run(this->inputs_, this->outputs_, inference_kernels, this->context_->allocator.get(), before, | |||
| after); | |||
| } | |||
| } | |||
| void TrainSession::train() { | |||
| void TrainSession::Train() { | |||
| for (auto *kernel : kernels_) { | |||
| MS_ASSERT(nullptr != kernel); | |||
| kernel->train(); | |||
| } | |||
| for (auto it1 = ext_output_map_.begin(); it1 != ext_output_map_.end(); ++it1) { | |||
| if ((output_node_map_.find(it1->first) == output_node_map_.end()) || train_mode_) { | |||
| // Delete if not from output_node_map_ | |||
| auto tensor_ptr = it1->second.back(); | |||
| delete tensor_ptr; | |||
| it1->second.pop_back(); | |||
| } | |||
| } | |||
| ext_output_map_.clear(); | |||
| output_node_map_.clear(); | |||
| output_tensor_map_.clear(); | |||
| train_mode_ = true; | |||
| for (auto kernel : this->kernels_) { | |||
| if (dynamic_cast<const kernel::LossKernel *>(kernel) != nullptr) { | |||
| auto *ms_tensor = new lite::Tensor(*kernel->out_tensors().at(0)); | |||
| ext_output_map_[kernel->name()].emplace_back(ms_tensor); | |||
| auto *ms_tensor = kernel->out_tensors().at(0); | |||
| if (ms_tensor != nullptr) { | |||
| output_node_map_[kernel->name()].emplace_back(ms_tensor); | |||
| auto index = TSFindTensor(tensors_, ms_tensor); | |||
| if (index != tensors_.size()) { | |||
| output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void TrainSession::eval() { | |||
| for (auto *kernel : kernels_) { | |||
| void TrainSession::Eval() { | |||
| for (auto *kernel : this->kernels_) { | |||
| MS_ASSERT(nullptr != kernel); | |||
| kernel->eval(); | |||
| } | |||
| kernel::LiteKernel *last_kernel = nullptr; | |||
| for (auto it1 = ext_output_map_.begin(); it1 != ext_output_map_.end(); ++it1) { | |||
| if ((output_node_map_.find(it1->first) == output_node_map_.end()) || train_mode_) { | |||
| // Delete if not from output_node_map_ | |||
| auto tensor_ptr = it1->second.back(); | |||
| delete tensor_ptr; | |||
| it1->second.pop_back(); | |||
| } | |||
| } | |||
| ext_output_map_ = output_node_map_; | |||
| output_node_map_ = orig_output_map_; | |||
| output_tensor_map_ = orig_output_tensor_map_; | |||
| train_mode_ = false; | |||
| for (auto kernel : this->kernels_) { | |||
| if ((dynamic_cast<const kernel::LossKernel *>(kernel) != nullptr) && (last_kernel != nullptr)) { | |||
| if (ext_output_map_.find(last_kernel->name()) == ext_output_map_.end()) { | |||
| auto *ms_tensor = new lite::Tensor(*last_kernel->out_tensors().at(0)); | |||
| ext_output_map_[last_kernel->name()].emplace_back(ms_tensor); | |||
| if (output_node_map_.find(last_kernel->name()) == output_node_map_.end()) { | |||
| auto *ms_tensor = last_kernel->out_tensors().at(0); | |||
| if (ms_tensor != nullptr) { | |||
| output_node_map_[last_kernel->name()].emplace_back(ms_tensor); | |||
| auto index = TSFindTensor(tensors_, ms_tensor); | |||
| if (index != tensors_.size()) { | |||
| output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| last_kernel = kernel; | |||
| } | |||
| } | |||
| std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> TrainSession::GetOutputMap() const { | |||
| return ext_output_map_; | |||
| } | |||
| std::vector<tensor::MSTensor *> TrainSession::GetOutputsByName(const std::string &name) const { | |||
| auto ret_vect = LiteSession::GetOutputsByNodeName(name); // TODO(emir): GetOutputsByTensorName? | |||
| if (ret_vect.size() > 0) return ret_vect; | |||
| auto ret = ext_output_map_.find(name); | |||
| if (ret == ext_output_map_.end()) { | |||
| MS_LOG(WARNING) << "Node " << name << " is not an output node"; | |||
| std::vector<mindspore::tensor::MSTensor *> empty_ret; | |||
| return empty_ret; | |||
| } | |||
| return ret->second; | |||
| } | |||
| } // namespace mindspore::session | |||
| @@ -214,28 +214,8 @@ if (SUPPORT_TRAIN) | |||
| # ${LITE_DIR}/src/train/ops/train_ops.cc | |||
| ${LITE_DIR}/src/train/train_populate_parameter.cc | |||
| ${LITE_DIR}/src/train/train_session.cc | |||
| ${LITE_DIR}/src/train/train_model.cc | |||
| ${LITE_DIR}/src/lite_session.cc | |||
| # ${SRC_DIR}/common/trans.cc | |||
| # ${SRC_DIR}/common/lite/trans_extends.cc | |||
| # ${SRC_DIR}/kernel/kernel_build_info.cc | |||
| # ${SRC_DIR}/utils/lite/base_ref_utils.cc | |||
| # ${SRC_DIR}/session/lite/anf_runtime_algorithm_extends.cc | |||
| # ${SRC_DIR}/session/lite/session_basic_extends.cc | |||
| # ${SRC_DIR}/session/anf_runtime_algorithm.cc | |||
| # ${SRC_DIR}/session/anf_runtime_algorithm.cc | |||
| # ${SRC_DIR}/session/session_basic.cc | |||
| # ${SRC_DIR}/session/kernel_graph.cc | |||
| # ${SRC_DIR}/session/session_factory.cc | |||
| # ${SRC_DIR}/device/kernel_info.cc | |||
| # ${SRC_DIR}/device/kernel_runtime.cc | |||
| # ${SRC_DIR}/device/lite/kernel_runtime_extends.cc | |||
| # ${LITE_DIR}/src/common/anf_importer/anf_importer.cc | |||
| # ${LITE_DIR}/src/common/anf_importer/import_from_meta_graph.cc | |||
| # ${LITE_DIR}/src/ir/primitive_value.cc | |||
| # ${LITE_DIR}/src/train/lite_kernel_runtime.cc | |||
| # ${LITE_DIR}/src/train/train_session.cc | |||
| # ${LITE_DIR}/src/train/model.cc | |||
| ${LITE_DIR}/src/lite_session.cc # temporary | |||
| ) | |||
| else() | |||
| set(TEST_LITE_SRC | |||
| @@ -21,8 +21,8 @@ | |||
| #include "src/common/file_utils_ext.h" | |||
| #include "src/runtime/kernel/arm/fp32_grad/bn_grad.h" | |||
| #include "nnacl/fp32_grad/batch_norm.h" | |||
| #include "nnacl/fp32/batchnorm.h" | |||
| #include "src/kernel_registry.h" | |||
| # | |||
| namespace mindspore { | |||
| @@ -43,7 +43,7 @@ lite::Tensor *TestBNGradFp32::CreateInTensor(std::string file_name, std::vector< | |||
| TEST_F(TestBNGradFp32, BNGradFp32) { | |||
| // prepare stage | |||
| auto bn_param = static_cast<BNGradParameter*>(malloc(sizeof(BNGradParameter))); | |||
| auto bn_param = static_cast<BNGradParameter *>(malloc(sizeof(BNGradParameter))); | |||
| bn_param->epsilon_ = 0.00001; | |||
| bn_param->momentum_ = 0.1; | |||
| const int batch = 2; | |||
| @@ -111,17 +111,85 @@ TEST_F(TestBNGradFp32, BNGradFp32) { | |||
| MS_LOG(INFO) << "BNGradFp32 passed"; | |||
| } | |||
| #if 0 | |||
| TEST_F(TestBNGradFp32, BNTtrainFp32) { | |||
| auto bn_param = static_cast<BNGradParameter*>(malloc(sizeof(BNGradParameter))); | |||
| auto bn_param = static_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter))); | |||
| bn_param->epsilon_ = 0.00001; | |||
| bn_param->momentum_ = 0.1; | |||
| const int batch = 2; | |||
| const int channels = 3; | |||
| const int height = 4; | |||
| const int width = 5; | |||
| bn_param->channel_ = channels; | |||
| auto x_tensor = CreateInTensor("./test_data/bngrad/input_x_2_4_5_3.bin", {batch, height, width, channels}); | |||
| std::vector<lite::Tensor *> inputs = {x_tensor, x_tensor, scale_tensor, mean_tensor, var_tensor}; | |||
| lite::Tensor scale_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| scale_tensor.MallocData(); | |||
| auto scale = reinterpret_cast<float *>(scale_tensor.MutableData()); | |||
| std::fill(scale, scale + channels, 1.0f); | |||
| lite::Tensor bias_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| bias_tensor.MallocData(); | |||
| auto bias = reinterpret_cast<float *>(bias_tensor.MutableData()); | |||
| std::fill(bias, bias + channels, 1.0f); | |||
| lite::Tensor mean_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| mean_tensor.MallocData(); | |||
| auto mean = reinterpret_cast<float *>(mean_tensor.MutableData()); | |||
| std::fill(mean, mean + channels, 0.0f); | |||
| lite::Tensor var_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| var_tensor.MallocData(); | |||
| auto var = reinterpret_cast<float *>(var_tensor.MutableData()); | |||
| std::fill(var, var + channels, 1.0f); | |||
| std::vector<lite::Tensor *> inputs = {x_tensor, &scale_tensor, &bias_tensor, &mean_tensor, &var_tensor}; | |||
| lite::Tensor out_tensor(TypeId::kNumberTypeFloat32, {batch, height, width, channels}); | |||
| ASSERT_EQ(out_tensor.MallocData(), 0); | |||
| lite::Tensor run_mean_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| ASSERT_EQ(run_mean_tensor.MallocData(), 0); | |||
| lite::Tensor run_var_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| ASSERT_EQ(run_var_tensor.MallocData(), 0); | |||
| lite::Tensor save_mean_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| ASSERT_EQ(save_mean_tensor.MallocData(), 0); | |||
| lite::Tensor save_var_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| ASSERT_EQ(save_var_tensor.MallocData(), 0); | |||
| std::vector<lite::Tensor *> outputs = {&out_tensor, &run_mean_tensor, &run_var_tensor, &save_mean_tensor, | |||
| &save_var_tensor}; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_FusedBatchNorm}; | |||
| mindspore::lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(bn_param), &context, desc, nullptr); | |||
| kernel_obj->train(); | |||
| kernel_obj->Run(); | |||
| float *run_mean = reinterpret_cast<float *>(run_mean_tensor.MutableData()); | |||
| float *run_var = reinterpret_cast<float *>(run_var_tensor.MutableData()); | |||
| std::cout << "================run_mean==============================\n"; | |||
| for (int i = 0; i < channels; i++) std::cout << run_mean[i] << " "; | |||
| std::cout << "\n"; | |||
| std::cout << "================run_var==============================\n"; | |||
| for (int i = 0; i < channels; i++) std::cout << run_var[i] << " "; | |||
| std::cout << "\n"; | |||
| delete[] reinterpret_cast<float *>(x_tensor->MutableData()); | |||
| auto res = mindspore::lite::CompareRelativeOutput(run_mean, "./test_data/bngrad/running_mean_3.bin"); | |||
| EXPECT_EQ(res, 0); | |||
| res = mindspore::lite::CompareRelativeOutput(run_var, "./test_data/bngrad/running_var_3.bin"); | |||
| EXPECT_EQ(res, 0); | |||
| x_tensor->SetData(nullptr); | |||
| delete x_tensor; | |||
| delete kernel_obj; | |||
| } | |||
| #endif | |||
| } // namespace mindspore | |||
| @@ -515,10 +515,6 @@ TEST_F(TestConvolutionGradFp32, ConvGroupDilation) { | |||
| auto *kernel = new mindspore::kernel::ConvolutionTrainCPUKernel(reinterpret_cast<OpParameter *>(conv_param), inputs, | |||
| outputs, &context, 0); | |||
| kernel->Init(); | |||
| // kernel::KernelKey desc = {kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Conv2D}; | |||
| // auto creator = lite::KernelRegistry::GetInstance()->GetKernelCreator(desc); | |||
| // auto kernel = creator(inputs, outputs, (OpParameter *)conv_param, &context, desc); | |||
| kernel->train(); | |||
| EXPECT_EQ(kernel->is_train(), 1); | |||
| @@ -23,7 +23,7 @@ | |||
| #include <functional> | |||
| #include "mindspore/lite/schema/inner/model_generated.h" | |||
| #include "mindspore/lite/include/model.h" | |||
| #include "mindspore/lite/include/train_model.h" | |||
| #include "common/common_test.h" | |||
| #include "include/train_session.h" | |||
| // #include "include/lite_session.h" | |||
| @@ -249,13 +249,6 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| label->dataType = TypeId::kNumberTypeInt32; | |||
| label->dims = {BATCH_SIZE * NUM_CLASSES}; | |||
| label->offset = -1; | |||
| // label->data.resize(BATCH_SIZE * NUM_CLASSES * sizeof(float)); | |||
| // int *data = reinterpret_cast<int *>(label->data.data()); | |||
| // for (int i = 0; i < BATCH_SIZE; i++) { | |||
| // for (int j = 0; j < NUM_CLASSES; j++) { | |||
| // *(data + i * NUM_CLASSES + j) = j; | |||
| // } | |||
| // } | |||
| meta_graph->allTensors.emplace_back(std::move(label)); | |||
| } | |||
| // tensor 7 - Softmaxentropy | |||
| @@ -353,20 +346,9 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| builder.Finish(offset); | |||
| size_t size = builder.GetSize(); | |||
| const char *content = reinterpret_cast<char *>(builder.GetBufferPointer()); | |||
| std::cout << "build fb size= " << size << "\n"; | |||
| std::cout << "build fb size= " << size << std::endl; | |||
| #if 0 // EXPORT_FILE | |||
| std::string path = std::string("hcdemo_train.fb"); | |||
| std::ofstream ofs(path); | |||
| ASSERT_EQ(true, ofs.good()); | |||
| ASSERT_EQ(true, ofs.is_open()); | |||
| ofs.seekp(0, std::ios::beg); | |||
| ofs.write(content, size); | |||
| ofs.close(); | |||
| #endif | |||
| auto model = lite::Model::Import(content, size); | |||
| auto model = lite::TrainModel::Import(content, size); | |||
| ASSERT_NE(nullptr, model); | |||
| meta_graph.reset(); | |||
| content = nullptr; | |||
| @@ -380,8 +362,8 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| session->Init(&context); | |||
| auto ret = session->CompileGraph(model); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| session->train(); | |||
| session->train(); // Just double check that calling train twice does not cause a problem | |||
| session->Train(); | |||
| session->Train(); // Just double check that calling Train twice does not cause a problem | |||
| auto inputs = session->GetInputs(); | |||
| ASSERT_EQ(inputs.size(), 2); | |||
| @@ -407,22 +389,20 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| ret = session->RunGraph(); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| auto outputs = session->GetOutputsByName("BiasAdd"); | |||
| auto outputs = session->GetOutputsByNodeName("SoftmaxCrossEntropy"); | |||
| ASSERT_EQ(outputs.size(), 1); | |||
| auto outTensor = (outputs.at(0)); | |||
| ASSERT_NE(nullptr, outTensor); | |||
| ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type()); | |||
| auto *outData = reinterpret_cast<float *>(outTensor->MutableData()); | |||
| ASSERT_NE(nullptr, outData); | |||
| std::cout << "==============Initial=Scores===================" << std::endl; | |||
| for (int i = 0; i < 10; i++) { | |||
| std::cout << outData[i] << ", "; | |||
| } | |||
| std::cout << std::endl; | |||
| session->eval(); | |||
| session->eval(); // Just double check that calling eval twice does not cause a problem | |||
| std::cout << "==============Initial=Loss=====================" << std::endl; | |||
| std::cout << outData[0] << ", " << std::endl; | |||
| session->Eval(); | |||
| session->Eval(); // Just double check that calling eval twice does not cause a problem | |||
| ret = session->RunGraph(); | |||
| outputs = session->GetOutputsByName("BiasAdd"); | |||
| outputs = session->GetOutputsByNodeName("BiasAdd"); | |||
| ASSERT_EQ(outputs.size(), 1); | |||
| outTensor = (outputs.at(0)); | |||
| ASSERT_NE(nullptr, outTensor); | |||
| @@ -433,14 +413,14 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| for (int i = 0; i < 10; i++) { | |||
| std::cout << outData[i] << ", "; | |||
| } | |||
| std::cout << std::endl; | |||
| std::string output_path = "./test_data/train/train_output_32_10.bin"; | |||
| auto error = lite::RelativeOutputError(outData, output_path); | |||
| EXPECT_LT(error, 2e-3); | |||
| ret = session->RunGraph(); | |||
| outputs = session->GetOutputsByName("BiasAdd"); | |||
| ASSERT_EQ(outputs.size(), 1); | |||
| outTensor = (outputs.at(0)); | |||
| auto all_output_tensors = session->GetOutputs(); | |||
| outTensor = (all_output_tensors["5"]); | |||
| ASSERT_NE(nullptr, outTensor); | |||
| ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type()); | |||
| outData = reinterpret_cast<float *>(outTensor->MutableData()); | |||
| @@ -449,15 +429,14 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| for (int i = 0; i < 10; i++) { | |||
| std::cout << outData[i] << ", "; | |||
| } | |||
| std::cout << std::endl; | |||
| error = lite::RelativeOutputError(outData, output_path); | |||
| EXPECT_LT(error, 2e-3); | |||
| session->train(); | |||
| session->eval(); // do some more zig-zags | |||
| session->Train(); | |||
| session->Eval(); // do some more zig-zags | |||
| ret = session->RunGraph(); | |||
| outputs = session->GetOutputsByName("BiasAdd"); | |||
| ASSERT_EQ(outputs.size(), 1); | |||
| outTensor = (outputs.at(0)); | |||
| outTensor = session->GetOutputByTensorName("5"); | |||
| ASSERT_NE(nullptr, outTensor); | |||
| ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type()); | |||
| outData = reinterpret_cast<float *>(outTensor->MutableData()); | |||
| @@ -466,10 +445,10 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| for (int i = 0; i < 10; i++) { | |||
| std::cout << outData[i] << ", "; | |||
| } | |||
| std::cout << std::endl; | |||
| error = lite::RelativeOutputError(outData, output_path); | |||
| EXPECT_LT(error, 2e-3); | |||
| delete model; | |||
| delete session; | |||
| MS_LOG(INFO) << "TuningLayer passed"; | |||
| } | |||
| @@ -490,19 +469,16 @@ int32_t fileIterator(mindspore::session::TrainSession *session, const std::strin | |||
| } | |||
| void replaceExt(const std::string &src, std::string *dst) { *dst = src.substr(0, src.find_last_of('.')) + ".emb"; } | |||
| int32_t runEffNet(mindspore::lite::LiteSession *session, const std::string &in, const std::string &out) { | |||
| int32_t runNet(mindspore::lite::LiteSession *session, const std::string &in, const std::string &out, | |||
| const char *tensor_name) { | |||
| // setup input | |||
| auto inputs = session->GetInputs(); | |||
| // ASSERT_EQ(inputs.size(), 1); | |||
| auto inTensor = inputs.at(0); | |||
| // ASSERT_NE(nullptr, inTensor); | |||
| float *data = reinterpret_cast<float *>(inTensor->MutableData()); | |||
| size_t input_size; | |||
| float *in_buf = reinterpret_cast<float *>(lite::ReadFile(in.c_str(), &input_size)); | |||
| // ASSERT_NE(nullptr, data); | |||
| auto input_data = reinterpret_cast<float *>(in_buf); | |||
| // ASSERT_EQ(input_size, inTensor->Size()); | |||
| std::copy(input_data, input_data + inTensor->ElementsNum(), data); | |||
| delete[] in_buf; | |||
| @@ -510,11 +486,10 @@ int32_t runEffNet(mindspore::lite::LiteSession *session, const std::string &in, | |||
| session->RunGraph(); | |||
| // compare outputs | |||
| auto outputs = session->GetOutputs(); | |||
| auto output = ((outputs.begin())->second); | |||
| auto output = session->GetOutputByTensorName(tensor_name); | |||
| float *output_data = reinterpret_cast<float *>(output->MutableData()); | |||
| return mindspore::lite::CompareRelativeOutput(output_data, out.c_str()); | |||
| return mindspore::lite::CompareRelativeOutput(output_data, out); | |||
| } | |||
| TEST_F(NetworkTest, efficient_net) { | |||
| @@ -524,7 +499,7 @@ TEST_F(NetworkTest, efficient_net) { | |||
| std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms"; | |||
| ReadFile(net.c_str(), &net_size, &buf); | |||
| auto model = lite::Model::Import(buf, net_size); | |||
| auto model = lite::TrainModel::Import(buf, net_size); | |||
| delete[] buf; | |||
| auto context = new lite::InnerContext; | |||
| context->device_type_ = lite::DT_CPU; | |||
| @@ -533,60 +508,47 @@ TEST_F(NetworkTest, efficient_net) { | |||
| ASSERT_EQ(lite::RET_OK, context->Init()); | |||
| auto session = new mindspore::session::TrainSession(); | |||
| // auto session = new mindspore::lite::LiteSession(); | |||
| ASSERT_NE(session, nullptr); | |||
| auto ret = session->Init(context); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| ret = session->CompileGraph(model); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| session->eval(); | |||
| #if 0 | |||
| std::string path = "/opt/share/MiniBinEmbDataset/"; | |||
| auto res = fileIterator(session, path, [](mindspore::lite::LiteSession *session, const std::string &in) { | |||
| int32_t res = 0; | |||
| if (in.find(".bin") != std::string::npos) { | |||
| std::string out; | |||
| replaceExt(in, &out); | |||
| res = runEffNet(session, in, out); | |||
| std::cout << "input file: " << in << (res ? " Fail" : " Pass") << std::endl; | |||
| } | |||
| return res; | |||
| }); | |||
| #else | |||
| session->Eval(); | |||
| std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin"; | |||
| std::string out = "./test_data/nets/effNet_output_y_1_1000.bin"; | |||
| auto res = runEffNet(session, in, out); | |||
| #endif | |||
| // auto inputs = session->GetInputs(); | |||
| // ASSERT_EQ(inputs.size(), NUM_OF_INPUTS); | |||
| // auto inTensor = inputs.at(0); | |||
| // ASSERT_NE(nullptr, inTensor); | |||
| // float *data = reinterpret_cast<float *>(inTensor->MutableData()); | |||
| // // fill input | |||
| // std::string input_path = "./test_data/nets/effNet_input_x_1_3_224_224.bin"; | |||
| // // std::string input_path = "/opt/share/MiniBinEmbDataset/2_pet/n02099601_3111.bin"; | |||
| // size_t input_size; | |||
| // char *in_buf = nullptr; | |||
| // ReadFile(input_path.c_str(), &input_size, &in_buf); | |||
| // ASSERT_NE(nullptr, data); | |||
| // auto input_data = reinterpret_cast<float *>(in_buf); | |||
| // ASSERT_EQ(input_size, inTensor->Size()); | |||
| // std::copy(input_data, input_data+inTensor->ElementsNum(), data); | |||
| // // execute network | |||
| // ret = session->RunGraph(); | |||
| // // compare outputs | |||
| // std::string output_path = "./test_data/nets/effNet_output_y_1_1000.bin"; | |||
| // // std::string output_path = "/opt/share/MiniBinEmbDataset/2_pet/n02099601_3111.emb"; | |||
| // auto outputs = session->GetOutputs(); | |||
| // auto output = ((outputs.begin())->second); | |||
| // float* output_data = reinterpret_cast<float *>(output.at(0)->MutableData()); | |||
| // int res = lite::CompareRelativeOutput(output_data, output_path); | |||
| auto res = runNet(session, in, out, "631"); | |||
| ASSERT_EQ(res, 0); | |||
| delete session; | |||
| delete context; | |||
| } | |||
| TEST_F(NetworkTest, lenetnet) { | |||
| char *buf = nullptr; | |||
| size_t net_size = 0; | |||
| std::string net = "./test_data/nets/lenet_train.ms"; | |||
| ReadFile(net.c_str(), &net_size, &buf); | |||
| auto model = lite::TrainModel::Import(buf, net_size); | |||
| delete[] buf; | |||
| auto context = new lite::Context; | |||
| context->device_type_ = lite::DT_CPU; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->thread_num_ = 1; | |||
| auto session = new mindspore::session::TrainSession(); | |||
| ASSERT_NE(session, nullptr); | |||
| auto ret = session->Init(context); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| ret = session->CompileGraph(model); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| session->Eval(); | |||
| std::string in = "./test_data/nets/x_lenet.bin"; | |||
| std::string out = "./test_data/nets/y_lenet.bin"; | |||
| auto res = runNet(session, in, out, "24"); | |||
| ASSERT_EQ(res, 0); | |||
| delete model; | |||
| delete session; | |||
| delete context; | |||
| } | |||
| @@ -196,7 +196,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingBatchGradFp32) { | |||
| std::vector<int> dim_dx({3, 28, 28, 3}); | |||
| lite::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); | |||
| dx_tensor.MallocData(); | |||
| ASSERT_EQ(dx_tensor.MallocData(), 0); | |||
| auto output_data = reinterpret_cast<float *>(dx_tensor.MutableData()); | |||
| std::vector<lite::Tensor *> outputs = {&dx_tensor}; | |||
| @@ -253,7 +253,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride2Fp32) { | |||
| lite::Tensor yt_tensor(TypeId::kNumberTypeFloat32, dim_y); | |||
| yt_tensor.SetData(yt_data); | |||
| lite::Tensor out_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| out_tensor.MallocData(); | |||
| ASSERT_EQ(out_tensor.MallocData(), 0); | |||
| float *out_data = static_cast<float *>(out_tensor.MutableData()); | |||
| std::vector<lite::Tensor *> inputs = {&yt_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&out_tensor}; | |||
| @@ -312,7 +312,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride3Fp32) { | |||
| yt_tensor.SetData(yt_data); | |||
| lite::Tensor out_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| out_tensor.MallocData(); | |||
| ASSERT_EQ(out_tensor.MallocData(), 0); | |||
| auto out_data = static_cast<float *>(out_tensor.MutableData()); | |||
| std::vector<lite::Tensor *> inputs = {&yt_tensor, &x_tensor}; | |||
| @@ -399,124 +399,6 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { | |||
| MS_LOG(INFO) << "TestMaxPoolingGradFp32 passed"; | |||
| } | |||
| #if 0 | |||
| TEST_F(TestPoolingGradFp32, MaxPoolingKernelGradFp32) { | |||
| // prepare stage | |||
| auto maxpool = new PoolingParameter(); | |||
| InitPoolingParamFP32(maxpool); | |||
| maxpool->pool_mode_ = PoolMode_MaxPool; | |||
| maxpool->input_h_ = 30; | |||
| maxpool->input_w_ = 30; | |||
| maxpool->input_channel_ = 3; | |||
| maxpool->output_batch_ = 1; | |||
| maxpool->output_h_ = 10; | |||
| maxpool->output_w_ = 10; | |||
| maxpool->output_channel_ = 3; | |||
| maxpool->stride_h_ = 3; | |||
| maxpool->stride_w_ = 3; | |||
| maxpool->pad_u_ = 0; | |||
| maxpool->pad_d_ = 0; | |||
| maxpool->pad_l_ = 0; | |||
| maxpool->pad_r_ = 0; | |||
| size_t input_size; | |||
| size_t y_data_size = maxpool->output_batch_ * maxpool->output_channel_ * maxpool->output_h_ * maxpool->output_w_; | |||
| auto x_data = reinterpret_cast<float *>( | |||
| mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_2_x_1_30_30_3.bin", &input_size)); | |||
| std::vector<int> dim_x({1, 30, 30, 3}); | |||
| lite::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| x_tensor.SetData(x_data); | |||
| std::vector<lite::Tensor *> maxpool_inputs = {&x_tensor}; | |||
| auto y_data = new float[y_data_size]; | |||
| std::vector<int> dim_y({1, 10, 10, 3}); | |||
| lite::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); | |||
| y_tensor.SetData(y_data); | |||
| auto ind_data = new int[y_data_size]; | |||
| lite::Tensor ind_tensor(TypeId::kNumberTypeInt32, dim_y); | |||
| ind_tensor.SetData(ind_data); | |||
| std::vector<lite::Tensor *> maxpool_outputs = {&y_tensor, &ind_tensor}; | |||
| kernel::KernelKey maxpool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Pooling}; | |||
| auto maxpool_creator = lite::KernelRegistry::GetInstance()->GetCreator(maxpool_desc); | |||
| auto maxpoolobj = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast<OpParameter *>(maxpool), | |||
| NULL, maxpool_desc); | |||
| maxpoolobj->Run(); | |||
| printf("==================indices data=================\n"); | |||
| for (int i = 0; i < 10; i++) { | |||
| std::cout << ind_data[i] << " ,"; | |||
| } | |||
| std::cout << std::endl; | |||
| auto pooling_param = new PoolingParameter(); | |||
| InitPoolingParamFP32(pooling_param); | |||
| pooling_param->pool_mode_ = PoolMode_MaxPool; | |||
| pooling_param->input_h_ = 10; | |||
| pooling_param->input_w_ = 10; | |||
| pooling_param->input_channel_ = 3; | |||
| pooling_param->output_batch_ = 1; | |||
| pooling_param->output_h_ = 30; | |||
| pooling_param->output_w_ = 30; | |||
| pooling_param->output_channel_ = 3; | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| // uint64_t time_avg = 0; | |||
| size_t output_data_size = | |||
| pooling_param->output_batch_ * pooling_param->output_channel_ * pooling_param->output_h_ * pooling_param->output_w_; | |||
| auto dy_data = reinterpret_cast<float *>( | |||
| mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_2_dy_1_10_10_3.bin", &input_size)); | |||
| std::vector<int> dim_dy({1, 3, 10, 10}); | |||
| lite::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||
| dy_tensor.SetData(dy_data); | |||
| #if 0 | |||
| std::string i_path = "./test_data/pooling/maxpoolgradfp32_2_i_1_3_10_10.bin"; | |||
| auto ill_data = reinterpret_cast<int64_t*>(mindspore::lite::ReadFile(i_path.c_str(), &input_size)); | |||
| auto i_data = new int[output_data_size]; | |||
| for (int i=0; i < output_data_size; i++) | |||
| i_data[i] = static_cast<int>(ill_data[i]); | |||
| std::vector<int> dim_ind({1, 3, 10, 10}); | |||
| lite::Tensor ind_tensor(TypeId::kNumberTypeInt32, dim_ind); | |||
| ind_tensor.SetData(i_data); | |||
| #endif | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &ind_tensor}; | |||
| auto output_data = new float[output_data_size]; | |||
| std::vector<int> dim_dx({1, 3, 30, 30}); | |||
| lite::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); | |||
| dx_tensor.SetData(output_data); | |||
| std::vector<lite::Tensor *> outputs = {&dx_tensor}; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(pooling_param), NULL, desc); | |||
| kernel_obj->Run(); | |||
| printf("==================output data=================\n"); | |||
| for (int i = 0; i < 20; i++) { | |||
| std::cout << output_data[i] << " ,"; | |||
| } | |||
| std::cout << std::endl; | |||
| std::string output_path = "./test_data/pooling/maxpoolgradfp32_2_dx_1_30_30_3.bin"; | |||
| lite::CompareOutput(output_data, output_path); | |||
| // delete input_data; | |||
| // delete[] output_data; | |||
| delete pooling_param; | |||
| MS_LOG(INFO) << "TestMaxPoolingKernelGradFp32 passed"; | |||
| } | |||
| #endif // if 0 before MaxPoolingKernelGradFp32 | |||
| TEST_F(TestPoolingGradFp32, MaxPoolGradBatchFp32) { | |||
| // prepare stage | |||
| // input size will be equal to the original size of x, output size will be the output size as in forward | |||
| @@ -547,7 +429,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradBatchFp32) { | |||
| yt_tensor.SetData(yt_data); | |||
| lite::Tensor out_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| out_tensor.MallocData(); | |||
| ASSERT_EQ(out_tensor.MallocData(), 0); | |||
| auto out_data = static_cast<float *>(out_tensor.MutableData()); | |||
| std::vector<lite::Tensor *> maxpool_inputs = {&x_tensor, &y_tensor, &yt_tensor}; | |||
| std::vector<lite::Tensor *> maxpool_outputs = {&out_tensor}; | |||
| @@ -616,7 +498,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride2Fp32) { | |||
| yt_tensor.SetData(yt_data); | |||
| lite::Tensor out_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| out_tensor.MallocData(); | |||
| ASSERT_EQ(out_tensor.MallocData(), 0); | |||
| auto out_data = static_cast<float *>(out_tensor.MutableData()); | |||
| std::vector<lite::Tensor *> maxpool_inputs = {&x_tensor, &y_tensor, &yt_tensor}; | |||
| @@ -686,7 +568,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride3Fp32) { | |||
| yt_tensor.SetData(yt_data); | |||
| lite::Tensor out_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| out_tensor.MallocData(); | |||
| ASSERT_EQ(out_tensor.MallocData(), 0); | |||
| auto out_data = static_cast<float *>(out_tensor.MutableData()); | |||
| std::vector<lite::Tensor *> maxpool_inputs = {&x_tensor, &y_tensor, &yt_tensor}; | |||
| @@ -0,0 +1 @@ | |||
| -<<= | |||
| @@ -0,0 +1 @@ | |||
| c箒?� �?;�[? | |||
| @@ -0,0 +1 @@ | |||
| R<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼR<h<ᨠ<0ڼrQ|;`ʼ | |||
| @@ -208,13 +208,17 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee | |||
| break; | |||
| } | |||
| if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem || | |||
| primitive_c->Type() == schema::PrimitiveType_MakeTuple || | |||
| primitive_c->Type() == schema::PrimitiveType_Depend) { | |||
| primitive_c->Type() == schema::PrimitiveType_MakeTuple | |||
| #ifdef SUPPORT_TRAIN | |||
| || primitive_c->Type() == schema::PrimitiveType_Depend | |||
| #endif | |||
| ) { | |||
| continue; | |||
| } | |||
| RemoveIfMakeTuple(cnode); | |||
| #ifdef SUPPORT_TRAIN | |||
| RemoveIfDepend(cnode); | |||
| #endif | |||
| auto primT = primitive_c->GetPrimitiveT(); | |||
| auto node = std::make_unique<schema::CNodeT>(); | |||
| if (node == nullptr) { | |||
| @@ -380,6 +384,10 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||
| } else if (value->isa<mindspore::ValueSequeue>()) { | |||
| #ifndef SUPPORT_TRAIN | |||
| MS_LOG(DEBUG) << "Value type is ValueSequence."; | |||
| return RET_OK; | |||
| #else | |||
| auto valueAbstract = valueNode->abstract(); | |||
| auto abstractSequnce = utils::cast<abstract::AbstractSequeuePtr>(valueAbstract); | |||
| if (abstractSequnce->isa<abstract::AbstractTuple>()) { | |||
| @@ -410,22 +418,11 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||
| } else { | |||
| MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << "."; | |||
| } | |||
| } else if (value->isa<mindspore::BoolImm>()) { | |||
| auto valueAbstract = valueNode->abstract(); | |||
| auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract); | |||
| auto typePtr = abstractScalar->GetTypeTrack(); | |||
| paramTensor->dataType = typePtr->type_id(); | |||
| paramTensor->dims = {1}; | |||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||
| auto data = value->cast<mindspore::BoolImmPtr>(); | |||
| paramTensor->data.emplace_back(data->value()); | |||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||
| } else { | |||
| MS_LOG(ERROR) << "Not support value type , need add support."; | |||
| return RET_ERROR; | |||
| } | |||
| #endif | |||
| } else { | |||
| MS_LOG(ERROR) << "Not support value type , need add support."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -31,6 +31,9 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = { | |||
| schema::PrimitiveType_PoolingGrad, | |||
| schema::PrimitiveType_BiasGrad, | |||
| schema::PrimitiveType_BNGrad, | |||
| schema::PrimitiveType_ActivationGrad, | |||
| schema::PrimitiveType_ApplyMomentum, | |||
| #endif | |||
| schema::PrimitiveType_Conv2D, | |||
| schema::PrimitiveType_DeConv2D, | |||
| @@ -51,7 +54,8 @@ static const std::vector<schema::PrimitiveType> nhwcOpDualInputList = { | |||
| static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = { | |||
| #ifdef SUPPORT_TRAIN | |||
| schema::PrimitiveType_PoolingGrad | |||
| schema::PrimitiveType_PoolingGrad, | |||
| schema::PrimitiveType_ActivationGrad | |||
| #endif | |||
| }; | |||
| @@ -139,20 +139,16 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| return RET_ERROR; | |||
| } | |||
| STATUS status; | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| #ifdef SUPPORT_TRAIN | |||
| if (IsContain(GetNhwcAllInputOpList(), GetCNodeTType(**iter))) { | |||
| int idx_num = node->inputIndex.size(); | |||
| for (int i = 0; i < idx_num; i++) { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| int idx_num = node->inputIndex.size(); | |||
| for (int i = 0; i < idx_num; i++) { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } else if (IsContain(GetNhwcDualInputOpList(), GetCNodeTType(**iter))) { | |||
| for (int i = 0; i < 2; i++) { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status); | |||
| @@ -162,12 +158,27 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| } | |||
| } | |||
| } else { | |||
| iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status); | |||
| int idx = 0; | |||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_ApplyMomentum) | |||
| idx = 3; | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, idx, beforeNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| #else | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| #endif | |||
| iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||