| @@ -395,6 +395,7 @@ if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" || "X$ENABLE_GPU" = " | |||||
| git submodule update --init --recursive akg | git submodule update --init --recursive akg | ||||
| fi | fi | ||||
| build_exit() | build_exit() | ||||
| { | { | ||||
| echo "$@" >&2 | echo "$@" >&2 | ||||
| @@ -596,33 +597,40 @@ build_lite() | |||||
| build_lite_java_arm64() { | build_lite_java_arm64() { | ||||
| # build mindspore-lite arm64 | # build mindspore-lite arm64 | ||||
| if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch64.tar.gz" ]]; then | |||||
| JTARBALL=mindspore-lite-${VERSION_STR}-inference-android-aarch64 | |||||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | |||||
| JTARBALL=mindspore-lite-${VERSION_STR}-train-android-aarch64 | |||||
| fi | |||||
| if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/${JTARBALL}.tar.gz" ]]; then | |||||
| build_lite "arm64" "off" | build_lite "arm64" "off" | ||||
| fi | fi | ||||
| # copy arm64 so | # copy arm64 so | ||||
| cd ${BASEPATH}/output/ | cd ${BASEPATH}/output/ | ||||
| rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch64 | |||||
| tar -zxvf mindspore-lite-${VERSION_STR}-inference-android-aarch64.tar.gz | |||||
| rm -rf ${JTARBALL} | |||||
| tar -zxvf ${JTARBALL}.tar.gz | |||||
| [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/arm64-v8a/ | [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/arm64-v8a/ | ||||
| mkdir -p ${JAVA_PATH}/java/app/libs/arm64-v8a/ | mkdir -p ${JAVA_PATH}/java/app/libs/arm64-v8a/ | ||||
| cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch64/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | |||||
| echo mindspore-lite-${VERSION_STR}-inference-android-aarch64 | |||||
| [ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch64 | |||||
| cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | |||||
| [ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL} | |||||
| } | } | ||||
| build_lite_java_arm32() { | build_lite_java_arm32() { | ||||
| # build mindspore-lite arm32 | # build mindspore-lite arm32 | ||||
| if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch32.tar.gz" ]]; then | |||||
| JTARBALL=mindspore-lite-${VERSION_STR}-inference-android-aarch32 | |||||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | |||||
| JTARBALL=mindspore-lite-${VERSION_STR}-train-android-aarch32 | |||||
| fi | |||||
| if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/${JTARBALL}.tar.gz" ]]; then | |||||
| build_lite "arm32" "off" | build_lite "arm32" "off" | ||||
| fi | fi | ||||
| # copy arm32 so | # copy arm32 so | ||||
| cd ${BASEPATH}/output/ | cd ${BASEPATH}/output/ | ||||
| rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch32 | |||||
| tar -zxvf mindspore-lite-${VERSION_STR}-inference-android-aarch32.tar.gz | |||||
| rm -rf ${JTARBALL} | |||||
| tar -zxvf ${JTARBALL}.tar.gz | |||||
| [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | ||||
| mkdir -p ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | mkdir -p ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | ||||
| cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch32/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | |||||
| [ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch32 | |||||
| cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | |||||
| [ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL} | |||||
| } | } | ||||
| build_jni_arm64() { | build_jni_arm64() { | ||||
| @@ -635,7 +643,7 @@ build_jni_arm64() { | |||||
| -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ | -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ | ||||
| -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ | -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ | ||||
| -DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \ | -DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \ | ||||
| -DPLATFORM_ARM64=on "${JAVA_PATH}/java/app/src/main/native" | |||||
| -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DPLATFORM_ARM64=on "${JAVA_PATH}/java/app/src/main/native" | |||||
| make -j$THREAD_NUM | make -j$THREAD_NUM | ||||
| if [[ $? -ne 0 ]]; then | if [[ $? -ne 0 ]]; then | ||||
| echo "---------------- mindspore lite: build jni arm64 failed----------------" | echo "---------------- mindspore lite: build jni arm64 failed----------------" | ||||
| @@ -655,7 +663,7 @@ build_jni_arm32() { | |||||
| -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ | -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ | ||||
| -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ | -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ | ||||
| -DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \ | -DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \ | ||||
| -DPLATFORM_ARM32=on "${JAVA_PATH}/java/app/src/main/native" | |||||
| -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DPLATFORM_ARM32=on "${JAVA_PATH}/java/app/src/main/native" | |||||
| make -j$THREAD_NUM | make -j$THREAD_NUM | ||||
| if [[ $? -ne 0 ]]; then | if [[ $? -ne 0 ]]; then | ||||
| echo "---------------- mindspore lite: build jni arm32 failed----------------" | echo "---------------- mindspore lite: build jni arm32 failed----------------" | ||||
| @@ -3,7 +3,7 @@ APP:=bin/net_runner | |||||
| MSLIB:=mindspore-lite | MSLIB:=mindspore-lite | ||||
| MSDIR:=$(realpath package-$(TARGET)/lib) | MSDIR:=$(realpath package-$(TARGET)/lib) | ||||
| SRC:=src/net_runner.cc src/dataset.cc | |||||
| SRC:=src/net_runner.cc src/dataset.cc src/data_callbacks.cc | |||||
| OBJ:=$(SRC:.cc=.o) | OBJ:=$(SRC:.cc=.o) | ||||
| CFLAGS := -Ofast -std=c++17 \ | CFLAGS := -Ofast -std=c++17 \ | ||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_ | |||||
| #define MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <unordered_map> | |||||
| #include "include/train/train_loop.h" | |||||
| #include "src/dataset.h" | |||||
| using GraphPoint = std::pair<int, float>; | |||||
| class AccuracyMonitor : public mindspore::session::TrainLoopCallBack { | |||||
| public: | |||||
| explicit AccuracyMonitor(DataSet *dataset, int check_every_n, int max_steps = -1) | |||||
| : ds_(dataset), check_every_n_(check_every_n), max_steps_(max_steps) {} | |||||
| int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override; | |||||
| const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; } | |||||
| private: | |||||
| DataSet *ds_; | |||||
| std::vector<GraphPoint> accuracies_; | |||||
| int check_every_n_; | |||||
| int max_steps_; | |||||
| }; | |||||
| #endif // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_ | |||||
| @@ -0,0 +1,103 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <math.h> | |||||
| #include <getopt.h> | |||||
| #include <cstring> | |||||
| #include <iostream> | |||||
| #include <fstream> | |||||
| #include <utility> | |||||
| #include "src/net_runner.h" | |||||
| #include "include/context.h" | |||||
| #include "src/utils.h" | |||||
| #include "src/data_loader.h" | |||||
| #include "src/accuracy_monitor.h" | |||||
| static unsigned int seed = time(NULL); | |||||
| std::vector<int> FillInputDataUtil(const mindspore::session::TrainLoopCallBackData &cb_data, | |||||
| const std::vector<DataLabelTuple> &dataset, bool serially) { | |||||
| static unsigned int idx = 1; | |||||
| int total_size = dataset.size(); | |||||
| std::vector<int> labels_vec; | |||||
| auto inputs = cb_data.session_->GetInputs(); | |||||
| char *input_data = reinterpret_cast<char *>(inputs.at(0)->MutableData()); | |||||
| auto labels = reinterpret_cast<float *>(inputs.at(1)->MutableData()); | |||||
| int batch_size = inputs.at(0)->shape()[0]; | |||||
| int num_of_classes = inputs.at(1)->shape()[1]; | |||||
| int data_size = inputs.at(0)->Size() / batch_size; | |||||
| MS_ASSERT(total_size > 0); | |||||
| MS_ASSERT(input_data != nullptr); | |||||
| std::fill(labels, labels + inputs.at(1)->ElementsNum(), 0.f); | |||||
| for (int i = 0; i < batch_size; i++) { | |||||
| if (serially) { | |||||
| idx = ++idx % total_size; | |||||
| } else { | |||||
| idx = rand_r(&seed) % total_size; | |||||
| } | |||||
| int label = 0; | |||||
| char *data = nullptr; | |||||
| std::tie(data, label) = dataset[idx]; | |||||
| std::copy(data, data + data_size, input_data + i * data_size); | |||||
| labels[i * num_of_classes + label] = 1.0; // Model expects labels in onehot representation | |||||
| labels_vec.push_back(label); | |||||
| } | |||||
| return labels_vec; | |||||
| } | |||||
| void DataLoader::StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) { | |||||
| FillInputDataUtil(cb_data, ds_->train_data(), false); | |||||
| } | |||||
| int AccuracyMonitor::EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) { | |||||
| if ((cb_data.epoch_ + 1) % check_every_n_ != 0) return mindspore::session::RET_CONTINUE; | |||||
| float accuracy = 0.0; | |||||
| auto inputs = cb_data.session_->GetInputs(); | |||||
| int batch_size = inputs.at(0)->shape()[0]; | |||||
| int num_of_classes = ds_->num_of_classes(); | |||||
| int tests = ds_->test_data().size() / batch_size; | |||||
| if (max_steps_ != -1 && tests > max_steps_) tests = max_steps_; | |||||
| cb_data.session_->Eval(); | |||||
| for (int i = 0; i < tests; i++) { | |||||
| auto labels = FillInputDataUtil(cb_data, ds_->test_data(), false); | |||||
| cb_data.session_->RunGraph(); | |||||
| auto outputs = cb_data.session_->GetPredictions(); | |||||
| for (auto it = outputs.begin(); it != outputs.end(); ++it) { | |||||
| if (it->second->ElementsNum() == batch_size * num_of_classes) { | |||||
| auto scores = reinterpret_cast<float *>(it->second->MutableData()); | |||||
| for (int b = 0; b < batch_size; b++) { | |||||
| int max_idx = 0; | |||||
| float max_score = scores[num_of_classes * b]; | |||||
| for (int c = 1; c < num_of_classes; c++) { | |||||
| if (scores[num_of_classes * b + c] > max_score) { | |||||
| max_score = scores[num_of_classes * b + c]; | |||||
| max_idx = c; | |||||
| } | |||||
| } | |||||
| if (labels[b] == max_idx) accuracy += 1.0; | |||||
| } | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| accuracy /= static_cast<float>(batch_size * tests); | |||||
| accuracies_.push_back(std::make_pair(cb_data.epoch_, accuracy)); | |||||
| std::cout << cb_data.epoch_ + 1 << ":\tAccuracy is " << accuracy << std::endl; | |||||
| cb_data.session_->Train(); | |||||
| return mindspore::session::RET_CONTINUE; | |||||
| } | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_ | |||||
| #define MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <unordered_map> | |||||
| #include "include/train/train_loop.h" | |||||
| #include "src/dataset.h" | |||||
| class DataLoader : public mindspore::session::TrainLoopCallBack { | |||||
| public: | |||||
| explicit DataLoader(DataSet *dataset) : ds_(dataset) {} | |||||
| void StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) override; | |||||
| private: | |||||
| DataSet *ds_; | |||||
| }; | |||||
| #endif // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_ | |||||
| @@ -20,10 +20,21 @@ | |||||
| #include <cstring> | #include <cstring> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <fstream> | #include <fstream> | ||||
| #include <utility> | |||||
| #include "include/context.h" | #include "include/context.h" | ||||
| #include "include/train/loss_monitor.h" | |||||
| #include "include/train/ckpt_saver.h" | |||||
| #include "include/train/lr_scheduler.h" | |||||
| #include "include/train/classification_train_accuracy_monitor.h" | |||||
| #include "src/utils.h" | #include "src/utils.h" | ||||
| #include "src/data_loader.h" | |||||
| #include "src/accuracy_monitor.h" | |||||
| using mindspore::session::TrainLoopCallBack; | |||||
| using mindspore::session::TrainLoopCallBackData; | |||||
| static unsigned int seed = time(NULL); | |||||
| unsigned int NetRunner::seed_ = time(NULL); | |||||
| // Definition of callback function after forwarding operator. | // Definition of callback function after forwarding operator. | ||||
| bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inputs, | bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inputs, | ||||
| const std::vector<mindspore::tensor::MSTensor *> &after_outputs, | const std::vector<mindspore::tensor::MSTensor *> &after_outputs, | ||||
| @@ -54,15 +65,18 @@ bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inpu | |||||
| } | } | ||||
| NetRunner::~NetRunner() { | NetRunner::~NetRunner() { | ||||
| if (session_ != nullptr) delete session_; | |||||
| if (loop_ != nullptr) delete loop_; | |||||
| } | } | ||||
| void NetRunner::InitAndFigureInputs() { | void NetRunner::InitAndFigureInputs() { | ||||
| mindspore::lite::Context context; | mindspore::lite::Context context; | ||||
| context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = mindspore::lite::NO_BIND; | context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = mindspore::lite::NO_BIND; | ||||
| context.thread_num_ = 1; | |||||
| context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = false; | |||||
| context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; | |||||
| context.thread_num_ = 2; | |||||
| session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context); | |||||
| loop_ = mindspore::session::TrainLoop::CreateTrainLoop(ms_file_, &context); | |||||
| session_ = loop_->train_session(); | |||||
| MS_ASSERT(nullptr != session_); | MS_ASSERT(nullptr != session_); | ||||
| auto inputs = session_->GetInputs(); | auto inputs = session_->GetInputs(); | ||||
| @@ -76,71 +90,10 @@ void NetRunner::InitAndFigureInputs() { | |||||
| } | } | ||||
| } | } | ||||
| mindspore::tensor::MSTensor *NetRunner::SearchOutputsForSize(size_t size) const { | |||||
| auto outputs = session_->GetOutputs(); | |||||
| for (auto it = outputs.begin(); it != outputs.end(); ++it) { | |||||
| if (it->second->ElementsNum() == size) return it->second; | |||||
| } | |||||
| std::cout << "Model does not have an output tensor with size " << size << std::endl; | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<int> NetRunner::FillInputData(const std::vector<DataLabelTuple> &dataset, bool serially) const { | |||||
| std::vector<int> labels_vec; | |||||
| static unsigned int idx = 1; | |||||
| int total_size = dataset.size(); | |||||
| auto inputs = session_->GetInputs(); | |||||
| char *input_data = reinterpret_cast<char *>(inputs.at(data_index_)->MutableData()); | |||||
| auto labels = reinterpret_cast<float *>(inputs.at(label_index_)->MutableData()); | |||||
| MS_ASSERT(total_size > 0); | |||||
| MS_ASSERT(input_data != nullptr); | |||||
| std::fill(labels, labels + inputs.at(label_index_)->ElementsNum(), 0.f); | |||||
| for (int i = 0; i < batch_size_; i++) { | |||||
| if (serially) { | |||||
| idx = ++idx % total_size; | |||||
| } else { | |||||
| idx = rand_r(&seed_) % total_size; | |||||
| } | |||||
| int label = 0; | |||||
| char *data = nullptr; | |||||
| std::tie(data, label) = dataset[idx]; | |||||
| std::memcpy(input_data + i * data_size_, data, data_size_); | |||||
| labels[i * num_of_classes_ + label] = 1.0; // Model expects labels in onehot representation | |||||
| labels_vec.push_back(label); | |||||
| } | |||||
| return labels_vec; | |||||
| } | |||||
| float NetRunner::CalculateAccuracy(int max_tests) const { | |||||
| float accuracy = 0.0; | |||||
| const std::vector<DataLabelTuple> test_set = ds_.test_data(); | |||||
| int tests = test_set.size() / batch_size_; | |||||
| if (max_tests != -1 && tests < max_tests) tests = max_tests; | |||||
| session_->Eval(); | |||||
| for (int i = 0; i < tests; i++) { | |||||
| auto labels = FillInputData(test_set, (max_tests == -1)); | |||||
| session_->RunGraph(); | |||||
| auto outputsv = SearchOutputsForSize(batch_size_ * num_of_classes_); | |||||
| MS_ASSERT(outputsv != nullptr); | |||||
| auto scores = reinterpret_cast<float *>(outputsv->MutableData()); | |||||
| for (int b = 0; b < batch_size_; b++) { | |||||
| int max_idx = 0; | |||||
| float max_score = scores[num_of_classes_ * b]; | |||||
| for (int c = 0; c < num_of_classes_; c++) { | |||||
| if (scores[num_of_classes_ * b + c] > max_score) { | |||||
| max_score = scores[num_of_classes_ * b + c]; | |||||
| max_idx = c; | |||||
| } | |||||
| } | |||||
| if (labels[b] == max_idx) accuracy += 1.0; | |||||
| } | |||||
| } | |||||
| session_->Train(); | |||||
| accuracy /= static_cast<float>(batch_size_ * tests); | |||||
| return accuracy; | |||||
| float NetRunner::CalculateAccuracy(int max_tests) { | |||||
| AccuracyMonitor test_am(&ds_, 1, max_tests); | |||||
| test_am.EpochEnd(TrainLoopCallBackData(true, 0, session_, loop_)); | |||||
| return 0.0; | |||||
| } | } | ||||
| int NetRunner::InitDB() { | int NetRunner::InitDB() { | ||||
| @@ -155,35 +108,17 @@ int NetRunner::InitDB() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| float NetRunner::GetLoss() const { | |||||
| auto outputsv = SearchOutputsForSize(1); // Search for Loss which is a single value tensor | |||||
| MS_ASSERT(outputsv != nullptr); | |||||
| auto loss = reinterpret_cast<float *>(outputsv->MutableData()); | |||||
| return loss[0]; | |||||
| } | |||||
| int NetRunner::TrainLoop() { | int NetRunner::TrainLoop() { | ||||
| session_->Train(); | |||||
| float min_loss = 1000.; | |||||
| float max_acc = 0.; | |||||
| for (int i = 0; i < cycles_; i++) { | |||||
| FillInputData(ds_.train_data()); | |||||
| session_->RunGraph(nullptr, verbose_ ? after_callback : nullptr); | |||||
| float loss = GetLoss(); | |||||
| if (min_loss > loss) min_loss = loss; | |||||
| if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) { | |||||
| auto cpkt_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms"; | |||||
| session_->SaveToFile(cpkt_fn); | |||||
| } | |||||
| struct mindspore::lite::StepLRLambda step_lr_lambda(100, 0.9); | |||||
| mindspore::lite::LRScheduler step_lr_sched(mindspore::lite::StepLRLambda, static_cast<void *>(&step_lr_lambda), 100); | |||||
| if ((i + 1) % 100 == 0) { | |||||
| float acc = CalculateAccuracy(10); | |||||
| if (max_acc < acc) max_acc = acc; | |||||
| std::cout << i + 1 << ":\tLoss is " << std::setw(7) << loss << " [min=" << min_loss << "] " | |||||
| << " max_acc=" << max_acc << std::endl; | |||||
| } | |||||
| } | |||||
| mindspore::lite::LossMonitor lm(100); | |||||
| // mindspore::lite::ClassificationTrainAccuracyMonitor am(10); | |||||
| mindspore::lite::CkptSaver cs(1000, std::string("lenet")); | |||||
| AccuracyMonitor test_am(&ds_, 500, 10); | |||||
| DataLoader dl(&ds_); | |||||
| loop_->Train(cycles_, std::vector<TrainLoopCallBack *>{&dl, &lm, &test_am, &cs, &step_lr_sched}); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| @@ -194,8 +129,7 @@ int NetRunner::Main() { | |||||
| TrainLoop(); | TrainLoop(); | ||||
| float acc = CalculateAccuracy(); | |||||
| std::cout << "accuracy = " << acc << std::endl; | |||||
| CalculateAccuracy(); | |||||
| if (cycles_ > 0) { | if (cycles_ > 0) { | ||||
| auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(cycles_) + ".ms"; | auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(cycles_) + ".ms"; | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include "include/train_session.h" | #include "include/train_session.h" | ||||
| #include "include/train/train_loop.h" | |||||
| #include "include/ms_tensor.h" | #include "include/ms_tensor.h" | ||||
| #include "src/dataset.h" | #include "src/dataset.h" | ||||
| @@ -38,12 +39,13 @@ class NetRunner { | |||||
| int InitDB(); | int InitDB(); | ||||
| int TrainLoop(); | int TrainLoop(); | ||||
| std::vector<int> FillInputData(const std::vector<DataLabelTuple> &dataset, bool is_train_set = false) const; | std::vector<int> FillInputData(const std::vector<DataLabelTuple> &dataset, bool is_train_set = false) const; | ||||
| float CalculateAccuracy(int max_tests = -1) const; | |||||
| float CalculateAccuracy(int max_tests = -1); | |||||
| float GetLoss() const; | float GetLoss() const; | ||||
| mindspore::tensor::MSTensor *SearchOutputsForSize(size_t size) const; | mindspore::tensor::MSTensor *SearchOutputsForSize(size_t size) const; | ||||
| DataSet ds_; | DataSet ds_; | ||||
| mindspore::session::TrainSession *session_ = nullptr; | mindspore::session::TrainSession *session_ = nullptr; | ||||
| mindspore::session::TrainLoop *loop_ = nullptr; | |||||
| std::string ms_file_ = ""; | std::string ms_file_ = ""; | ||||
| std::string data_dir_ = ""; | std::string data_dir_ = ""; | ||||
| @@ -17,9 +17,10 @@ | |||||
| #include "src/net_runner.h" | #include "src/net_runner.h" | ||||
| #include <math.h> | #include <math.h> | ||||
| #include <getopt.h> | #include <getopt.h> | ||||
| #include <algorithm> | |||||
| #include <cstring> | #include <cstring> | ||||
| #include <iostream> | |||||
| #include <fstream> | #include <fstream> | ||||
| #include <iostream> | |||||
| #include "include/context.h" | #include "include/context.h" | ||||
| #include "src/utils.h" | #include "src/utils.h" | ||||
| @@ -113,7 +114,7 @@ std::vector<int> NetRunner::FillInputData(const std::vector<DataLabelTuple> &dat | |||||
| int label = 0; | int label = 0; | ||||
| char *data = nullptr; | char *data = nullptr; | ||||
| std::tie(data, label) = dataset[idx]; | std::tie(data, label) = dataset[idx]; | ||||
| std::memcpy(input_data + i * data_size_, data, data_size_); | |||||
| std::copy(data, data + data_size, input_data + i * data_size); | |||||
| labels[i * num_of_classes_ + label] = 1.0; // Model expects labels in onehot representation | labels[i * num_of_classes_ + label] = 1.0; // Model expects labels in onehot representation | ||||
| labels_vec.push_back(label); | labels_vec.push_back(label); | ||||
| } | } | ||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_ | |||||
| #define MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_ | |||||
| #include <stdio.h> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <unordered_map> | |||||
| #include "include/train/train_loop.h" | |||||
| using GraphPoint = std::pair<int, float>; | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class CkptSaver : public session::TrainLoopCallBack { | |||||
| public: | |||||
| CkptSaver(int save_every_n, const std::string &filename_prefix) | |||||
| : save_every_n_(save_every_n), filename_prefix_(filename_prefix) {} | |||||
| int EpochEnd(const session::TrainLoopCallBackData &cb_data) override { | |||||
| if ((cb_data.epoch_ + 1) % save_every_n_ == 0) { | |||||
| auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms"; | |||||
| remove(cpkt_fn.c_str()); | |||||
| cb_data.session_->SaveToFile(cpkt_fn); | |||||
| } | |||||
| return session::RET_CONTINUE; | |||||
| } | |||||
| private: | |||||
| int save_every_n_; | |||||
| std::string filename_prefix_; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_ | |||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_ | |||||
| #define MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <climits> | |||||
| #include <unordered_map> | |||||
| #include "include/train/train_loop.h" | |||||
| using GraphPoint = std::pair<int, float>; | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class ClassificationTrainAccuracyMonitor : public session::TrainLoopCallBack { | |||||
| public: | |||||
| explicit ClassificationTrainAccuracyMonitor(int print_every_n = INT_MAX) : print_every_n_(print_every_n) {} | |||||
| virtual ~ClassificationTrainAccuracyMonitor() = default; | |||||
| void Begin(const session::TrainLoopCallBackData &cb_data) override; | |||||
| void EpochBegin(const session::TrainLoopCallBackData &cb_data) override; | |||||
| int EpochEnd(const session::TrainLoopCallBackData &cb_data) override; | |||||
| void StepEnd(const session::TrainLoopCallBackData &cb_data) override; | |||||
| const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; } | |||||
| private: | |||||
| std::vector<GraphPoint> accuracies_; | |||||
| int print_every_n_ = 0; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_ | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_ | |||||
| #define MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <climits> | |||||
| #include <unordered_map> | |||||
| #include "include/train/train_loop_callback.h" | |||||
| using GraphPoint = std::pair<int, float>; | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class LossMonitor : public session::TrainLoopCallBack { | |||||
| public: | |||||
| explicit LossMonitor(int print_every_n = INT_MAX) : print_every_n_(print_every_n) {} | |||||
| virtual ~LossMonitor() = default; | |||||
| void Begin(const session::TrainLoopCallBackData &cb_data) override; | |||||
| void EpochBegin(const session::TrainLoopCallBackData &cb_data) override; | |||||
| int EpochEnd(const session::TrainLoopCallBackData &cb_data) override; | |||||
| void StepEnd(const session::TrainLoopCallBackData &cb_data) override; | |||||
| const std::vector<GraphPoint> &GetLossPoints() const { return losses_; } | |||||
| private: | |||||
| std::vector<GraphPoint> losses_; | |||||
| int print_every_n_; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_ | |||||
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_ | |||||
| #define MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <functional> | |||||
| #include <unordered_map> | |||||
| #include "include/train/train_loop_callback.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| constexpr int DONT_UPDATE_LR = 0; | |||||
| constexpr int UPDATE_LR = 1; | |||||
| using LR_Lambda = std::function<int(float *lr, int epoch, void *cb_data)>; | |||||
| /// \brief Multiply the LR by a factor of gamma every epoch | |||||
| int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication); | |||||
| /// \brief Multiply the LR by a factor of gamma every step_size | |||||
| int StepLRLambda(float *lr, int epoch, void *step_size); | |||||
| struct StepLRLambda { | |||||
| StepLRLambda(int step, float g) : step_size(step), gamma(g) {} | |||||
| int step_size; // period of LR decay | |||||
| float gamma; // LR decay factor | |||||
| }; | |||||
| class LRScheduler : public session::TrainLoopCallBack { | |||||
| public: | |||||
| explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step_ = 1); | |||||
| virtual ~LRScheduler() = default; | |||||
| int EpochEnd(const session::TrainLoopCallBackData &cb_data) override; | |||||
| private: | |||||
| LR_Lambda lambda_func_; | |||||
| void *lr_data_ = nullptr; | |||||
| int step_ = 1; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_ | |||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ | |||||
| #define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <tuple> | |||||
| #include <unordered_map> | |||||
| #include "include/train/train_loop_callback.h" | |||||
| #include "include/train_session.h" | |||||
| namespace mindspore { | |||||
| namespace session { | |||||
| class TrainLoop { | |||||
| public: | |||||
| /// \brief Static method to create a TrainLoop object | |||||
| /// | |||||
| /// \param[in] filename Filename to read flatbuffer from | |||||
| /// \param[in] context Defines the context of the session to be created | |||||
| /// | |||||
| /// \return Pointer of MindSpore Lite TrainLoop | |||||
| static TrainLoop *CreateTrainLoop(const std::string &model_filename, lite::Context *context, int batch_size = -1); | |||||
| /// \brief Class destructor | |||||
| virtual ~TrainLoop() = default; | |||||
| /// \brief Resets the epoch counter | |||||
| /// | |||||
| /// \return 0 on success or -1 in case of error | |||||
| virtual int Reset() = 0; // resets the epoch counter to 0. | |||||
| /// \brief Accessor to the TrainSession | |||||
| /// | |||||
| /// \return pointer of the train_session | |||||
| virtual session::TrainSession *train_session() = 0; | |||||
| /// \brief Accessor to the Session KernelCallbacks | |||||
| /// | |||||
| /// \param[in] before Define a call_back_function to be called before running each node. | |||||
| /// \param[in] after Define a call_back_function called after running each node. | |||||
| /// | |||||
| /// \return 0 on success or -1 in case of error | |||||
| virtual int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) = 0; | |||||
| /// \brief Performs the training Loop | |||||
| /// | |||||
| /// \param[in] epoch The number of epochs to run | |||||
| /// \param[in] cbs A vector of TrainLoopCallBack objects | |||||
| /// | |||||
| /// \return 0 on success or -1 in case of error | |||||
| virtual int Train(int epochs, std::vector<TrainLoopCallBack *> cbs) = 0; | |||||
| }; | |||||
| } // namespace session | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_ | |||||
| #define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <tuple> | |||||
| #include <unordered_map> | |||||
| namespace mindspore { | |||||
| namespace session { | |||||
| class TrainSession; | |||||
| class TrainLoop; | |||||
| struct TrainLoopCallBackData { | |||||
| TrainLoopCallBackData(bool train_mode, int epoch, TrainSession *session, TrainLoop *loop) | |||||
| : train_mode_(train_mode), epoch_(epoch), session_(session), loop_(loop) {} | |||||
| bool train_mode_; /**< training mode of TrainSession object */ | |||||
| unsigned int epoch_; /**< the current training epoch (starts at 0) */ | |||||
| unsigned int step_ = 0; /**< the current step within the epoch */ | |||||
| TrainSession *session_; /**< pointer to the TrainSession */ | |||||
| TrainLoop *loop_; | |||||
| }; | |||||
| constexpr int RET_CONTINUE = 0; | |||||
| constexpr int RET_STOP_TRAINING = 1; | |||||
| constexpr int RET_EXIT = 2; | |||||
| class TrainLoopCallBack { | |||||
| public: | |||||
| virtual ~TrainLoopCallBack() = default; | |||||
| virtual void Begin(const TrainLoopCallBackData &cb_data) {} | |||||
| virtual void End(const TrainLoopCallBackData &cb_data) {} | |||||
| virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {} | |||||
| virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; } | |||||
| virtual void StepBegin(const TrainLoopCallBackData &cb_data) {} | |||||
| virtual void StepEnd(const TrainLoopCallBackData &cb_data) {} | |||||
| }; | |||||
| } // namespace session | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_ | |||||
| @@ -83,6 +83,23 @@ class TrainSession : public session::LiteSession { | |||||
| /// \return boolean indication if model is in eval mode | /// \return boolean indication if model is in eval mode | ||||
| bool IsEval() { return train_mode_ == false; } | bool IsEval() { return train_mode_ == false; } | ||||
| /// \brief Sets the Learning Rate of the training | |||||
| /// | |||||
| /// \param[in] learning_rate to set | |||||
| /// | |||||
| /// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h | |||||
| virtual int SetLearningRate(float learning_rate) = 0; | |||||
| /// \brief Gets the Learning Rate of the training | |||||
| /// | |||||
| /// \return learning rate. 0.0 if no optimizer was found | |||||
| virtual float GetLearningRate() = 0; | |||||
| /// \brief Get output MindSpore Lite MSTensors of Training model prediction | |||||
| /// | |||||
| /// \return The map of output tensor name and MindSpore Lite MSTensor. | |||||
| virtual std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetPredictions() const = 0; | |||||
| protected: | protected: | ||||
| bool train_mode_ = false; | bool train_mode_ = false; | ||||
| }; | }; | ||||
| @@ -0,0 +1,178 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| package com.mindspore.lite; | |||||
| import java.util.ArrayList; | |||||
| import java.util.HashMap; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| import java.util.Set; | |||||
| import com.mindspore.lite.config.MSConfig; | |||||
| public class TrainSession { | |||||
| static { | |||||
| System.loadLibrary("mindspore-lite-jni"); | |||||
| } | |||||
| private long sessionPtr; | |||||
| public TrainSession() { | |||||
| this.sessionPtr = 0; | |||||
| } | |||||
| public boolean init(String modelFilename, MSConfig config) { | |||||
| this.sessionPtr = createSession(modelFilename, config.getMSConfigPtr()); | |||||
| return this.sessionPtr != 0; | |||||
| } | |||||
| public long getSessionPtr() { | |||||
| return sessionPtr; | |||||
| } | |||||
| public void bindThread(boolean if_bind) { | |||||
| this.bindThread(this.sessionPtr, if_bind); | |||||
| } | |||||
| public boolean runGraph() { | |||||
| return this.runGraph(this.sessionPtr); | |||||
| } | |||||
| public List<MSTensor> getInputs() { | |||||
| List<Long> ret = this.getInputs(this.sessionPtr); | |||||
| ArrayList<MSTensor> tensors = new ArrayList<MSTensor>(); | |||||
| for (Long ms_tensor_addr : ret) { | |||||
| MSTensor msTensor = new MSTensor(ms_tensor_addr); | |||||
| tensors.add(msTensor); | |||||
| } | |||||
| return tensors; | |||||
| } | |||||
| public MSTensor getInputsByTensorName(String tensorName) { | |||||
| Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName); | |||||
| if(tensor_addr == null){ | |||||
| return null; | |||||
| } | |||||
| MSTensor msTensor = new MSTensor(tensor_addr); | |||||
| return msTensor; | |||||
| } | |||||
| public List<MSTensor> getOutputsByNodeName(String nodeName) { | |||||
| List<Long> ret = this.getOutputsByNodeName(this.sessionPtr, nodeName); | |||||
| ArrayList<MSTensor> tensors = new ArrayList<>(); | |||||
| for (Long msTensorAddr : ret) { | |||||
| MSTensor msTensor = new MSTensor(msTensorAddr); | |||||
| tensors.add(msTensor); | |||||
| } | |||||
| return tensors; | |||||
| } | |||||
| public Map<String, MSTensor> getOutputMapByTensor() { | |||||
| Map<String, Long> ret = this.getOutputMapByTensor(this.sessionPtr); | |||||
| Map<String, MSTensor> tensorMap = new HashMap<>(); | |||||
| Set<Map.Entry<String, Long>> entrySet = ret.entrySet(); | |||||
| for (Map.Entry<String, Long> entry : entrySet) { | |||||
| String name = entry.getKey(); | |||||
| Long msTensorAddr = entry.getValue(); | |||||
| tensorMap.put(name, new MSTensor(msTensorAddr)); | |||||
| } | |||||
| return tensorMap; | |||||
| } | |||||
| public List<String> getOutputTensorNames() { | |||||
| return getOutputTensorNames(this.sessionPtr); | |||||
| } | |||||
| public MSTensor getOutputByTensorName(String tensorName) { | |||||
| Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName); | |||||
| if(tensor_addr == null){ | |||||
| return null; | |||||
| } | |||||
| return new MSTensor(tensor_addr); | |||||
| } | |||||
| public void free() { | |||||
| this.free(this.sessionPtr); | |||||
| this.sessionPtr = 0; | |||||
| } | |||||
| public boolean resize(List<MSTensor> inputs, int[][] dims) { | |||||
| long[] inputs_array = new long[inputs.size()]; | |||||
| for (int i = 0; i < inputs.size(); i++) { | |||||
| inputs_array[i] = inputs.get(i).getMSTensorPtr(); | |||||
| } | |||||
| return this.resize(this.sessionPtr, inputs_array, dims); | |||||
| } | |||||
| public boolean saveToFile(String modelFilename) { | |||||
| return this.saveToFile(this.sessionPtr, modelFilename); | |||||
| } | |||||
| public boolean train() { | |||||
| return this.train(this.sessionPtr); | |||||
| } | |||||
| public boolean eval() { | |||||
| return this.eval(this.sessionPtr); | |||||
| } | |||||
| public boolean isTrain() { | |||||
| return this.isTrain(this.sessionPtr); | |||||
| } | |||||
| public boolean isEval() { | |||||
| return this.isEval(this.sessionPtr); | |||||
| } | |||||
| public boolean setLearningRate(float learning_rate) { | |||||
| return this.setLearningRate(this.sessionPtr, learning_rate); | |||||
| } | |||||
| private native long createSession(String modelFilename, long msConfigPtr); | |||||
| private native void bindThread(long sessionPtr, boolean if_bind); | |||||
| private native boolean runGraph(long sessionPtr); | |||||
| private native List<Long> getInputs(long sessionPtr); | |||||
| private native long getInputsByTensorName(long sessionPtr, String tensorName); | |||||
| private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName); | |||||
| private native Map<String, Long> getOutputMapByTensor(long sessionPtr); | |||||
| private native List<String> getOutputTensorNames(long sessionPtr); | |||||
| private native long getOutputByTensorName(long sessionPtr, String tensorName); | |||||
| private native void free(long sessionPtr); | |||||
| private native boolean resize(long sessionPtr, long[] inputs, int[][] dims); | |||||
| private native boolean saveToFile(long sessionPtr, String modelFilename); | |||||
| private native boolean train(long sessionPtr); | |||||
| private native boolean eval(long sessionPtr); | |||||
| private native boolean isTrain(long sessionPtr); | |||||
| private native boolean isEval(long sessionPtr); | |||||
| private native boolean setLearningRate(long sessionPtr, float learning_rate); | |||||
| } | |||||
| @@ -7,8 +7,10 @@ set(PLATFORM_ARM "on") | |||||
| set(MS_VERSION_MAJOR ${MS_VERSION_MAJOR}) | set(MS_VERSION_MAJOR ${MS_VERSION_MAJOR}) | ||||
| set(MS_VERSION_MINOR ${MS_VERSION_MINOR}) | set(MS_VERSION_MINOR ${MS_VERSION_MINOR}) | ||||
| set(MS_VERSION_REVISION ${MS_VERSION_REVISION}) | set(MS_VERSION_REVISION ${MS_VERSION_REVISION}) | ||||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}") | |||||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}") | |||||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \ | |||||
| -DMS_VERSION_REVISION=${MS_VERSION_REVISION}") | |||||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \ | |||||
| -DMS_VERSION_REVISION=${MS_VERSION_REVISION}") | |||||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") | ||||
| #set for cross-compiling toolchain | #set for cross-compiling toolchain | ||||
| @@ -16,16 +18,16 @@ set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) | |||||
| set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH) | set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH) | ||||
| set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH) | set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH) | ||||
| if (ENABLE_VERBOSE) | |||||
| if(ENABLE_VERBOSE) | |||||
| set(CMAKE_VERBOSE_MAKEFILE on) | set(CMAKE_VERBOSE_MAKEFILE on) | ||||
| endif () | |||||
| endif() | |||||
| if (PLATFORM_ARM32) | |||||
| if(PLATFORM_ARM32) | |||||
| add_compile_definitions(ENABLE_ARM32) | add_compile_definitions(ENABLE_ARM32) | ||||
| endif () | |||||
| if (PLATFORM_ARM64) | |||||
| endif() | |||||
| if(PLATFORM_ARM64) | |||||
| add_compile_definitions(ENABLE_ARM64) | add_compile_definitions(ENABLE_ARM64) | ||||
| endif () | |||||
| endif() | |||||
| set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../../..) | set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../../..) | ||||
| set(LITE_DIR ${TOP_DIR}/mindspore/lite) | set(LITE_DIR ${TOP_DIR}/mindspore/lite) | ||||
| @@ -40,15 +42,25 @@ include_directories(${LITE_DIR}/build) ## flatbuffers | |||||
| link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../libs/${ANDROID_ABI}/) | link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../libs/${ANDROID_ABI}/) | ||||
| add_library(mindspore-lite-jni SHARED | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/common/jni_utils.cpp | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/model.cpp | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/version.cpp | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_config.cpp | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_tensor.cpp | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp | |||||
| ) | |||||
| set(JNI_SRC | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/common/jni_utils.cpp | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/model.cpp | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/version.cpp | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_config.cpp | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_tensor.cpp | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp | |||||
| ) | |||||
| if(SUPPORT_TRAIN) | |||||
| set(JNI_SRC | |||||
| ${JNI_SRC} | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp | |||||
| ) | |||||
| endif() | |||||
| add_library(mindspore-lite-jni SHARED ${JNI_SRC}) | |||||
| find_library(log-lib log) | find_library(log-lib log) | ||||
| target_link_libraries(mindspore-lite-jni mindspore-lite ${log-lib}) | |||||
| target_link_libraries(mindspore-lite-jni mindspore-lite ${log-lib}) | |||||
| @@ -0,0 +1,305 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <jni.h> | |||||
| #include "common/ms_log.h" | |||||
| #include "common/jni_utils.h" | |||||
| #include "include/train_session.h" | |||||
| #include "include/errorcode.h" | |||||
| extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createSession(JNIEnv *env, jobject thiz, | |||||
| jstring model_file_name, | |||||
| jlong ms_config_ptr) { | |||||
| auto *pointer = reinterpret_cast<void *>(ms_config_ptr); | |||||
| if (pointer == nullptr) { | |||||
| MS_LOGE("Context pointer from java is nullptr"); | |||||
| return jlong(nullptr); | |||||
| } | |||||
| auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer); | |||||
| auto session = mindspore::session::TrainSession::CreateSession(JstringToChar(env, model_file_name), lite_context_ptr); | |||||
| if (session == nullptr) { | |||||
| MS_LOGE("CreateSession failed"); | |||||
| return jlong(nullptr); | |||||
| } | |||||
| return jlong(session); | |||||
| } | |||||
| extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_TrainSession_bindThread(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr, jboolean if_bind) { | |||||
| auto *pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer); | |||||
| train_session_ptr->BindThread(if_bind); | |||||
| } | |||||
| extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_runGraph(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr) { | |||||
| auto *pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return (jboolean) false; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer); | |||||
| auto ret = train_session_ptr->RunGraph(); | |||||
| return (jboolean)(ret == mindspore::lite::RET_OK); | |||||
| } | |||||
| extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getInputs(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr) { | |||||
| jclass array_list = env->FindClass("java/util/ArrayList"); | |||||
| jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V"); | |||||
| jobject ret = env->NewObject(array_list, array_list_construct); | |||||
| jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z"); | |||||
| jclass long_object = env->FindClass("java/lang/Long"); | |||||
| jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V"); | |||||
| auto *pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return ret; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer); | |||||
| auto inputs = train_session_ptr->GetInputs(); | |||||
| for (auto input : inputs) { | |||||
| jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input)); | |||||
| env->CallBooleanMethod(ret, array_list_add, tensor_addr); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_getInputsByTensorName(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr, | |||||
| jstring tensor_name) { | |||||
| auto *pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return jlong(nullptr); | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer); | |||||
| auto input = train_session_ptr->GetInputsByTensorName(JstringToChar(env, tensor_name)); | |||||
| return jlong(input); | |||||
| } | |||||
| extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getOutputsByNodeName(JNIEnv *env, | |||||
| jobject thiz, | |||||
| jlong session_ptr, | |||||
| jstring node_name) { | |||||
| jclass array_list = env->FindClass("java/util/ArrayList"); | |||||
| jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V"); | |||||
| jobject ret = env->NewObject(array_list, array_list_construct); | |||||
| jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z"); | |||||
| jclass long_object = env->FindClass("java/lang/Long"); | |||||
| jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V"); | |||||
| auto *pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return ret; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer); | |||||
| auto inputs = train_session_ptr->GetOutputsByNodeName(JstringToChar(env, node_name)); | |||||
| for (auto input : inputs) { | |||||
| jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input)); | |||||
| env->CallBooleanMethod(ret, array_list_add, tensor_addr); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getOutputMapByTensor(JNIEnv *env, | |||||
| jobject thiz, | |||||
| jlong session_ptr) { | |||||
| jclass hash_map_clazz = env->FindClass("java/util/HashMap"); | |||||
| jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V"); | |||||
| jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct); | |||||
| jmethodID hash_map_put = | |||||
| env->GetMethodID(hash_map_clazz, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"); | |||||
| auto *pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return hash_map; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer); | |||||
| auto outputs = train_session_ptr->GetOutputs(); | |||||
| jclass long_object = env->FindClass("java/lang/Long"); | |||||
| jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V"); | |||||
| for (auto output_iter : outputs) { | |||||
| auto node_name = output_iter.first; | |||||
| auto ms_tensor = output_iter.second; | |||||
| jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor)); | |||||
| env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), tensor_addr); | |||||
| } | |||||
| return hash_map; | |||||
| } | |||||
| extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getOutputTensorNames(JNIEnv *env, | |||||
| jobject thiz, | |||||
| jlong session_ptr) { | |||||
| jclass array_list = env->FindClass("java/util/ArrayList"); | |||||
| jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V"); | |||||
| jobject ret = env->NewObject(array_list, array_list_construct); | |||||
| jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z"); | |||||
| auto *pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return ret; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer); | |||||
| auto output_names = train_session_ptr->GetOutputTensorNames(); | |||||
| for (auto output_name : output_names) { | |||||
| env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str())); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_getOutputByTensorName(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr, | |||||
| jstring tensor_name) { | |||||
| auto *pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return jlong(nullptr); | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer); | |||||
| auto output = train_session_ptr->GetOutputByTensorName(JstringToChar(env, tensor_name)); | |||||
| return jlong(output); | |||||
| } | |||||
| extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_TrainSession_free(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr) { | |||||
| auto *pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer); | |||||
| delete (train_session_ptr); | |||||
| } | |||||
| extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_resize(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr, jlongArray inputs, | |||||
| jobjectArray dims) { | |||||
| std::vector<std::vector<int>> c_dims; | |||||
| auto *pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return false; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(pointer); | |||||
| jsize input_size = static_cast<int>(env->GetArrayLength(inputs)); | |||||
| jlong *input_data = env->GetLongArrayElements(inputs, nullptr); | |||||
| std::vector<mindspore::tensor::MSTensor *> c_inputs; | |||||
| for (int i = 0; i < input_size; i++) { | |||||
| auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]); | |||||
| if (tensor_pointer == nullptr) { | |||||
| MS_LOGE("Tensor pointer from java is nullptr"); | |||||
| return false; | |||||
| } | |||||
| auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(tensor_pointer); | |||||
| c_inputs.push_back(ms_tensor_ptr); | |||||
| } | |||||
| jsize tensor_size = static_cast<int>(env->GetArrayLength(dims)); | |||||
| for (int i = 0; i < tensor_size; i++) { | |||||
| jintArray array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i)); | |||||
| jsize dim_size = static_cast<int>(env->GetArrayLength(array)); | |||||
| jint *dim_data = env->GetIntArrayElements(array, nullptr); | |||||
| std::vector<int> tensor_dims; | |||||
| for (int j = 0; j < dim_size; j++) { | |||||
| tensor_dims.push_back(dim_data[j]); | |||||
| } | |||||
| c_dims.push_back(tensor_dims); | |||||
| } | |||||
| int ret = train_session_ptr->Resize(c_inputs, c_dims); | |||||
| return (jboolean)(ret == mindspore::lite::RET_OK); | |||||
| } | |||||
| extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_saveToFile(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr, | |||||
| jstring model_file_name) { | |||||
| auto *session_pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (session_pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return (jboolean) false; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer); | |||||
| auto ret = train_session_ptr->SaveToFile(JstringToChar(env, model_file_name)); | |||||
| return (jboolean)(ret == 0); | |||||
| } | |||||
| extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_train(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr) { | |||||
| auto *session_pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (session_pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return (jboolean) false; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer); | |||||
| auto ret = train_session_ptr->Train(); | |||||
| return (jboolean)(ret == mindspore::lite::RET_OK); | |||||
| } | |||||
| extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_eval(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr) { | |||||
| auto *session_pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (session_pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return (jboolean) false; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer); | |||||
| auto ret = train_session_ptr->Eval(); | |||||
| return (jboolean)(ret == mindspore::lite::RET_OK); | |||||
| } | |||||
| extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_isTrain(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr) { | |||||
| auto *session_pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (session_pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return (jboolean) false; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer); | |||||
| auto ret = train_session_ptr->IsTrain(); | |||||
| return (jboolean)(ret); | |||||
| } | |||||
| extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_isEval(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr) { | |||||
| auto *session_pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (session_pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return (jboolean) false; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer); | |||||
| auto ret = train_session_ptr->IsEval(); | |||||
| return (jboolean)(ret); | |||||
| } | |||||
| extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setLearningRate(JNIEnv *env, jobject thiz, | |||||
| jlong session_ptr, | |||||
| jfloat learning_rate) { | |||||
| auto *session_pointer = reinterpret_cast<void *>(session_ptr); | |||||
| if (session_pointer == nullptr) { | |||||
| MS_LOGE("Session pointer from java is nullptr"); | |||||
| return (jboolean) false; | |||||
| } | |||||
| auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer); | |||||
| auto ret = train_session_ptr->SetLearningRate(learning_rate); | |||||
| return (jboolean)(ret == mindspore::lite::RET_OK); | |||||
| } | |||||
| @@ -800,8 +800,8 @@ int ElementSubRelu6(const float *in0, const float *in1, float *out, int size) { | |||||
| int BroadcastDiv(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, | int BroadcastDiv(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, | ||||
| ArithmeticParameter *param) { | ArithmeticParameter *param) { | ||||
| TileDimensionsFp32(in0, in1, tile_in0, tile_in0, param); | |||||
| return ElementDiv(tile_in0, tile_in0, out, size); | |||||
| TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param); | |||||
| return ElementDiv(tile_in0, tile_in1, out, size); | |||||
| } | } | ||||
| int ElementDiv(const float *in0, const float *in1, float *out, int size) { | int ElementDiv(const float *in0, const float *in1, float *out, int size) { | ||||
| @@ -21,32 +21,48 @@ | |||||
| #include "nnacl/errorcode.h" | #include "nnacl/errorcode.h" | ||||
| inline int ReluGrad(float *src0, float *src1, size_t length, float *dst) { | inline int ReluGrad(float *src0, float *src1, size_t length, float *dst) { | ||||
| for (size_t i = 0; i < length; ++i) { | |||||
| if (src1[i] > 0) { | |||||
| dst[i] = src0[i]; | |||||
| } else { | |||||
| dst[i] = 0; | |||||
| } | |||||
| int i = 0; | |||||
| #ifdef ENABLE_ARM | |||||
| float32x4_t zero_4 = vdupq_n_f32(0.0f); | |||||
| for (; i < length - 4; i += 4) { | |||||
| float32x4_t src1_4 = vld1q_f32(src1 + i); | |||||
| float32x4_t src0_4 = vld1q_f32(src0 + i); | |||||
| uint32x4_t mask_4 = vcgtq_f32(src1_4, zero_4); | |||||
| float32x4_t dst_4 = vbslq_f32(mask_4, src0_4, zero_4); | |||||
| vst1q_f32(dst + i, dst_4); | |||||
| } | |||||
| #endif | |||||
| for (; i < length; ++i) { | |||||
| dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f; | |||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int Relu6Grad(float *src0, float *src1, size_t length, float *dst) { | int Relu6Grad(float *src0, float *src1, size_t length, float *dst) { | ||||
| for (size_t i = 0; i < length; ++i) { | |||||
| if (src1[i] > 0.0f && src1[i] <= 6.0f) { | |||||
| dst[i] = src0[i]; | |||||
| } else { | |||||
| dst[i] = 0.0f; | |||||
| } | |||||
| int i = 0; | |||||
| #ifdef ENABLE_ARM | |||||
| float32x4_t zero_4 = vdupq_n_f32(0.0f); | |||||
| float32x4_t six_4 = vdupq_n_f32(6.0f); | |||||
| for (; i < length - 4; i += 4) { | |||||
| float32x4_t src1_4 = vld1q_f32(src1 + i); | |||||
| float32x4_t src0_4 = vld1q_f32(src0 + i); | |||||
| float32x4_t max_4 = vmaxq_f32(src1_4, zero_4); | |||||
| float32x4_t min_max_4 = vminq_f32(max_4, six_4); | |||||
| uint32x4_t mask_4 = vceqq_f32(min_max_4, src1_4); | |||||
| float32x4_t dst_4 = vbslq_f32(mask_4, src0_4, zero_4); | |||||
| vst1q_f32(dst + i, dst_4); | |||||
| } | |||||
| #endif | |||||
| for (; i < length; ++i) { | |||||
| dst[i] = (src1[i] > 0.0f && src1[i] <= 6.0f) ? src0[i] : 0.0f; | |||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int LReluGrad(float *src0, float *src1, size_t length, float *dst, float alpha) { | int LReluGrad(float *src0, float *src1, size_t length, float *dst, float alpha) { | ||||
| for (size_t i = 0; i < length; ++i) { | for (size_t i = 0; i < length; ++i) { | ||||
| dst[i] = src1[i] > 0.0f ? 1.0f : alpha; | |||||
| dst[i] = src1[i] > 0.0f ? src0[i] : alpha * src0[i]; | |||||
| } | } | ||||
| ElementMul(src0, dst, dst, length); | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| @@ -17,55 +17,36 @@ | |||||
| #include <string.h> | #include <string.h> | ||||
| #include "nnacl/fp32_grad/batch_norm.h" | #include "nnacl/fp32_grad/batch_norm.h" | ||||
| void sumSpatialBatch(const float *in, size_t size, int ch, float *out) { | |||||
| memset(out, 0, ch * sizeof(float)); | |||||
| for (size_t i = 0; i < size; i++) { | |||||
| const float *ptr = in + (i * ch); | |||||
| for (size_t c = 0; c < ch; c++) { | |||||
| out[c] += ptr[c]; | |||||
| } | |||||
| void var2Invar(float *save_var, int size, float eps) { | |||||
| for (int i = 0; i < size; i++) { | |||||
| save_var[i] = 1.0f / sqrt(save_var[i] + eps); | |||||
| } | } | ||||
| } | } | ||||
| void backwardX(const float *in, const float *dout, const float *scale, const size_t size, int channels, float *mean, | |||||
| float *invar, float *dxhathat_sum, float *dxhat_sum, float *out) { | |||||
| const float N = (size); | |||||
| for (size_t i = 0; i < size; i++) { | |||||
| for (size_t f = 0; f < channels; f++) { | |||||
| size_t ix = i * channels + f; | |||||
| float x_hat = (in[ix] - mean[f]) * invar[f]; | |||||
| float dx_hat = dout[ix] * scale[f]; | |||||
| dxhat_sum[f] += dx_hat; | |||||
| dxhathat_sum[f] += dx_hat * x_hat; | |||||
| void backwardAll(const float *restrict in, const float *restrict yt, const float *restrict mean, | |||||
| const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dxhat_sum, | |||||
| float *restrict dxhathat_sum, float *restrict dbias, float *restrict dscale, float *restrict dx) { | |||||
| float N = (float)size; | |||||
| for (int i = 0; i < size; i++) { | |||||
| for (int c = 0; c < ch; c++) { | |||||
| int ix = i * ch + c; | |||||
| dbias[c] += yt[ix]; | |||||
| // dscale | |||||
| float x_hat = (in[ix] - mean[c]) * invar[c]; | |||||
| dscale[c] += (yt[ix] * x_hat); | |||||
| // dx_1 | |||||
| float dx_hat = yt[ix] * scale[c]; | |||||
| dxhat_sum[c] += dx_hat; | |||||
| dxhathat_sum[c] += dx_hat * x_hat; | |||||
| } | } | ||||
| } | } | ||||
| for (size_t i = 0; i < size; i++) { | |||||
| for (size_t f = 0; f < channels; f++) { | |||||
| size_t ix = i * channels + f; | |||||
| float x_hat = (in[ix] - mean[f]) * invar[f]; | |||||
| float dx_hat = dout[ix] * scale[f]; | |||||
| out[ix] = 1.0f / N * (invar[f]) * (N * dx_hat - dxhat_sum[f] - x_hat * dxhathat_sum[f]); | |||||
| } | |||||
| } | |||||
| } | |||||
| void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, int n, | |||||
| int size, float *scale_updates) { | |||||
| size_t i, b, f; | |||||
| memset(scale_updates, 0, n * sizeof(float)); | |||||
| for (b = 0; b < batch; ++b) { | |||||
| for (i = 0; i < size; ++i) { | |||||
| for (f = 0; f < n; ++f) { | |||||
| int index = (b * size + i) * n + f; | |||||
| float x_norm = (x[index] - mean[f]) * invar[f]; | |||||
| scale_updates[f] += (delta[index] * x_norm); | |||||
| } | |||||
| for (int i = 0; i < size; i++) { | |||||
| for (int c = 0; c < ch; c++) { | |||||
| // dx_2 | |||||
| int ix = i * ch + c; | |||||
| float x_hat = (in[ix] - mean[c]) * invar[c]; | |||||
| float dx_hat = yt[ix] * scale[c]; | |||||
| dx[ix] = 1.0f / N * (invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c]); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void var2Invar(float *save_var, size_t size, float eps) { | |||||
| for (size_t i = 0; i < size; i++) { | |||||
| save_var[i] = 1.0f / sqrt(save_var[i] + eps); | |||||
| } | |||||
| } | |||||
| @@ -29,13 +29,9 @@ typedef struct BNGradParameter { | |||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| void sumSpatialBatch(const float *in, size_t size, int ch, float *out); | |||||
| void backwardX(const float *in, const float *dout, const float *scale, const size_t size, int channels, float *mean, | |||||
| float *invar, float *xhat_sum, float *dxhat_sum, float *out); | |||||
| void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, int n, | |||||
| int size, float *scale_updates); | |||||
| void var2Invar(float *save_var, size_t size, float eps); | |||||
| void var2Invar(float *save_var, int size, float eps); | |||||
| void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, | |||||
| int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale, float *dx); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -20,7 +20,7 @@ | |||||
| int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y, | int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y, | ||||
| const float *weight, const float *dloss, float *dx) { | const float *weight, const float *dloss, float *dx) { | ||||
| const float epsilon = 1e-12; | |||||
| const float epsilon = 1e-12f; | |||||
| if (reduction == 0) { | if (reduction == 0) { | ||||
| for (int i = 0; i < input_size; i++) { | for (int i = 0; i < input_size; i++) { | ||||
| float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); | float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); | ||||
| @@ -21,7 +21,7 @@ | |||||
| #endif | #endif | ||||
| #include "nnacl/fp32/matmul_fp32.h" | #include "nnacl/fp32/matmul_fp32.h" | ||||
| static void addv(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) { | |||||
| void AddMatrix(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) { | |||||
| const float *src_ptr = v1; | const float *src_ptr = v1; | ||||
| float *dst_ptr = v2; | float *dst_ptr = v2; | ||||
| for (int r = 0; r < row; r++) { | for (int r = 0; r < row; r++) { | ||||
| @@ -86,7 +86,8 @@ static void RowMajor2Row12MajorStride(const float *src_ptr, float *dst_ptr, int | |||||
| return; | return; | ||||
| } | } | ||||
| static void RowMajor2Col12MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) { | |||||
| static void RowMajor2Col12MajorStride(const float *restrict src_ptr, float *restrict dst_ptr, size_t row, size_t col, | |||||
| int lead) { | |||||
| size_t row_up_12 = UP_ROUND(row, C12NUM); | size_t row_up_12 = UP_ROUND(row, C12NUM); | ||||
| size_t row12 = row / C12NUM * C12NUM; | size_t row12 = row / C12NUM * C12NUM; | ||||
| size_t col4 = col / C4NUM * C4NUM; | size_t col4 = col / C4NUM * C4NUM; | ||||
| @@ -549,7 +550,7 @@ void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const floa | |||||
| #else | #else | ||||
| MatMulOpt(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc); | MatMulOpt(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc); | ||||
| #endif | #endif | ||||
| if (incremental) addv(output, mat_c, beta, M, N, ldc); | |||||
| if (incremental) AddMatrix(output, mat_c, beta, M, N, ldc); | |||||
| gcb->mat_a = mat_a_input; | gcb->mat_a = mat_a_input; | ||||
| gcb->mat_b = mat_b_input; | gcb->mat_b = mat_b_input; | ||||
| } | } | ||||
| @@ -37,6 +37,7 @@ void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *m | |||||
| int ldb, float beta, float *mat_c, int ldc, float *workspace); | int ldb, float beta, float *mat_c, int ldc, float *workspace); | ||||
| int MatSize(int row, int col, int round); | int MatSize(int row, int col, int round); | ||||
| int MatSizeTotal(int row, int col, int deep, int inc); | int MatSizeTotal(int row, int col, int deep, int inc); | ||||
| void AddMatrix(const float *v1, float *v2, float beta, int row, int col, int stride); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -18,9 +18,8 @@ | |||||
| #include "nnacl/fp32_grad/pack_ext.h" | #include "nnacl/fp32_grad/pack_ext.h" | ||||
| #include "nnacl/pack.h" | #include "nnacl/pack.h" | ||||
| static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); } | |||||
| void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start) { | |||||
| void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num, | |||||
| int start) { | |||||
| const int pad_left = conv_param->pad_l_; | const int pad_left = conv_param->pad_l_; | ||||
| const int pad_up = conv_param->pad_u_; | const int pad_up = conv_param->pad_u_; | ||||
| @@ -43,22 +42,43 @@ void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParamet | |||||
| int kernel_row, kernel_col; | int kernel_row, kernel_col; | ||||
| for (int i = 0; i < rows; i++) { | |||||
| int block_start = start + i; | |||||
| int input_h = block_start / output_w * stride_h; | |||||
| int input_w = block_start % output_w * stride_w; | |||||
| for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | |||||
| int input_row = -pad_up + kernel_row * dilation_h + input_h; | |||||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | |||||
| int input_col = -pad_left + kernel_col * dilation_w + input_w; | |||||
| if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||||
| const int offset = (input_row * in_width + input_col) * tot_channels; | |||||
| memcpy(data_col, in_data + offset, sizeof(float) * channels); | |||||
| data_col += channels; | |||||
| } else { | |||||
| memset(data_col, 0, sizeof(float) * channels); | |||||
| data_col += channels; | |||||
| if (channels == 1) { | |||||
| for (int i = 0; i < real_cal_num; i++) { | |||||
| int block_start = start + i; | |||||
| int input_h = block_start / output_w * stride_h; | |||||
| int input_w = block_start % output_w * stride_w; | |||||
| for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | |||||
| int input_row = -pad_up + kernel_row * dilation_h + input_h; | |||||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | |||||
| int input_col = -pad_left + kernel_col * dilation_w + input_w; | |||||
| if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { | |||||
| const int offset = (input_row * in_width + input_col) * tot_channels; | |||||
| *data_col = in_data[offset]; | |||||
| data_col++; | |||||
| } else { | |||||
| *data_col = 0; | |||||
| data_col++; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (int i = 0; i < real_cal_num; i++) { | |||||
| int block_start = start + i; | |||||
| int input_h = block_start / output_w * stride_h; | |||||
| int input_w = block_start % output_w * stride_w; | |||||
| for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | |||||
| int input_row = -pad_up + kernel_row * dilation_h + input_h; | |||||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | |||||
| int input_col = -pad_left + kernel_col * dilation_w + input_w; | |||||
| if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { | |||||
| const int offset = (input_row * in_width + input_col) * tot_channels; | |||||
| memcpy(data_col, in_data + offset, sizeof(float) * channels); | |||||
| data_col += channels; | |||||
| } else { | |||||
| memset(data_col, 0, sizeof(float) * channels); | |||||
| data_col += channels; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -70,7 +90,6 @@ void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *con | |||||
| rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index); | rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index); | ||||
| } | } | ||||
| // output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w) | |||||
| void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, bool transpose) { | void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, bool transpose) { | ||||
| const int pad_left = conv_param->pad_l_; | const int pad_left = conv_param->pad_l_; | ||||
| const int pad_up = conv_param->pad_u_; | const int pad_up = conv_param->pad_u_; | ||||
| @@ -100,14 +119,14 @@ void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv | |||||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | ||||
| int input_row = -pad_up + kernel_row * dilation_h; | int input_row = -pad_up + kernel_row * dilation_h; | ||||
| for (output_rows = output_h; output_rows; output_rows--) { | for (output_rows = output_h; output_rows; output_rows--) { | ||||
| if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) { | |||||
| if (!((unsigned)(input_row) < (unsigned)(in_height))) { | |||||
| for (output_col = output_w; output_col; output_col--) { | for (output_col = output_w; output_col; output_col--) { | ||||
| *(data_row++) = 0; | *(data_row++) = 0; | ||||
| } | } | ||||
| } else { | } else { | ||||
| int input_col = -pad_left + kernel_col * dilation_w; | int input_col = -pad_left + kernel_col * dilation_w; | ||||
| for (output_col = output_w; output_col; output_col--) { | for (output_col = output_w; output_col; output_col--) { | ||||
| if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||||
| if (((unsigned)(input_col) < (unsigned)(in_width))) { | |||||
| const int offset = (input_row * in_width + input_col) * tot_channels + channel; | const int offset = (input_row * in_width + input_col) * tot_channels + channel; | ||||
| *(data_row++) = in_data[offset]; | *(data_row++) = in_data[offset]; | ||||
| } else { | } else { | ||||
| @@ -127,14 +146,14 @@ void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv | |||||
| for (channel = 0; channel < channels; channel++) { | for (channel = 0; channel < channels; channel++) { | ||||
| int input_row = -pad_up + kernel_row * dilation_h; | int input_row = -pad_up + kernel_row * dilation_h; | ||||
| for (output_rows = output_h; output_rows; output_rows--) { | for (output_rows = output_h; output_rows; output_rows--) { | ||||
| if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) { | |||||
| if (!((unsigned)(input_row) < (unsigned)(in_height))) { | |||||
| for (output_col = output_w; output_col; output_col--) { | for (output_col = output_w; output_col; output_col--) { | ||||
| *(data_row++) = 0; | *(data_row++) = 0; | ||||
| } | } | ||||
| } else { | } else { | ||||
| int input_col = -pad_left + kernel_col * dilation_w; | int input_col = -pad_left + kernel_col * dilation_w; | ||||
| for (output_col = output_w; output_col; output_col--) { | for (output_col = output_w; output_col; output_col--) { | ||||
| if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||||
| if (((unsigned)(input_col) < (unsigned)(in_width))) { | |||||
| const int offset = (input_row * in_width + input_col) * tot_channels + channel; | const int offset = (input_row * in_width + input_col) * tot_channels + channel; | ||||
| *(data_row++) = in_data[offset]; | *(data_row++) = in_data[offset]; | ||||
| } else { | } else { | ||||
| @@ -150,7 +169,6 @@ void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) { | void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) { | ||||
| const int pad_left = conv_param->pad_l_; | const int pad_left = conv_param->pad_l_; | ||||
| const int pad_up = conv_param->pad_u_; | const int pad_up = conv_param->pad_u_; | ||||
| @@ -177,14 +195,14 @@ void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParamet | |||||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | ||||
| for (output_rows = start; output_rows < start + rows; output_rows++) { | for (output_rows = start; output_rows < start + rows; output_rows++) { | ||||
| int input_row = -pad_up + kernel_row * dilation_h + output_rows * stride_h; | int input_row = -pad_up + kernel_row * dilation_h + output_rows * stride_h; | ||||
| if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) { | |||||
| if (!((unsigned)(input_row) < (unsigned)(in_height))) { | |||||
| for (output_col = output_w; output_col; output_col--) { | for (output_col = output_w; output_col; output_col--) { | ||||
| *(data_row++) = 0; | *(data_row++) = 0; | ||||
| } | } | ||||
| } else { | } else { | ||||
| int input_col = -pad_left + kernel_col * dilation_w; | int input_col = -pad_left + kernel_col * dilation_w; | ||||
| for (output_col = output_w; output_col; output_col--) { | for (output_col = output_w; output_col; output_col--) { | ||||
| if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||||
| if (((unsigned)(input_col) < (unsigned)(in_width))) { | |||||
| const int offset = (input_row * in_width + input_col) * tot_channels + channel; | const int offset = (input_row * in_width + input_col) * tot_channels + channel; | ||||
| *(data_row++) = in_data[offset]; | *(data_row++) = in_data[offset]; | ||||
| } else { | } else { | ||||
| @@ -193,7 +211,6 @@ void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParamet | |||||
| input_col += stride_w; | input_col += stride_w; | ||||
| } | } | ||||
| } | } | ||||
| // input_row += stride_h; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -232,8 +249,7 @@ void col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv | |||||
| int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; | int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; | ||||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | ||||
| int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; | int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; | ||||
| if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||||
| if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { | |||||
| int offset = (input_row * in_width + input_col) * tot_channels; | int offset = (input_row * in_width + input_col) * tot_channels; | ||||
| float *data_im_ptr = &data_im[offset]; | float *data_im_ptr = &data_im[offset]; | ||||
| for (int i = 0; i < channels; i++) { | for (int i = 0; i < channels; i++) { | ||||
| @@ -271,20 +287,36 @@ void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParamet | |||||
| int kernel_row, kernel_col; | int kernel_row, kernel_col; | ||||
| for (int r = 0; r < rows; r++) { | |||||
| int output_col = (start + r) % output_w; | |||||
| int output_row = (start + r) / output_w; | |||||
| int row_stride_offset = output_row * stride_h; | |||||
| int col_stride_offset = output_col * stride_w; | |||||
| // for (output_col = 0; output_col < output_w; output_col++) | |||||
| { | |||||
| if (channels == 1) { | |||||
| for (int r = 0; r < rows; r++) { | |||||
| int output_col = (start + r) % output_w; | |||||
| int output_row = (start + r) / output_w; | |||||
| int row_stride_offset = output_row * stride_h; | |||||
| int col_stride_offset = output_col * stride_w; | |||||
| for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | ||||
| int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; | int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; | ||||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | ||||
| int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; | int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; | ||||
| if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||||
| if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { | |||||
| int offset = (input_row * in_width + input_col) * tot_channels; | |||||
| float *data_im_ptr = &data_im[offset]; | |||||
| *data_im_ptr += *data_col; | |||||
| } | |||||
| data_col++; | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (int r = 0; r < rows; r++) { | |||||
| int output_col = (start + r) % output_w; | |||||
| int output_row = (start + r) / output_w; | |||||
| int row_stride_offset = output_row * stride_h; | |||||
| int col_stride_offset = output_col * stride_w; | |||||
| for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | |||||
| int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; | |||||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | |||||
| int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; | |||||
| if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { | |||||
| int offset = (input_row * in_width + input_col) * tot_channels; | int offset = (input_row * in_width + input_col) * tot_channels; | ||||
| float *data_im_ptr = &data_im[offset]; | float *data_im_ptr = &data_im[offset]; | ||||
| for (int i = 0; i < channels; i++) { | for (int i = 0; i < channels; i++) { | ||||
| @@ -308,7 +308,7 @@ table SoftmaxCrossEntropy { | |||||
| } | } | ||||
| table SparseSoftmaxCrossEntropy { | table SparseSoftmaxCrossEntropy { | ||||
| isGrad: int; | |||||
| isGrad: bool; | |||||
| } | } | ||||
| table make_tuple { | table make_tuple { | ||||
| @@ -1225,11 +1225,9 @@ table SmoothL1LossGrad { | |||||
| } | } | ||||
| table SigmoidCrossEntropyWithLogits { | table SigmoidCrossEntropyWithLogits { | ||||
| beta : float; | |||||
| } | } | ||||
| table SigmoidCrossEntropyWithLogitsGrad { | table SigmoidCrossEntropyWithLogitsGrad { | ||||
| beta : float; | |||||
| } | } | ||||
| table Reciprocal { | table Reciprocal { | ||||
| @@ -65,7 +65,10 @@ if(SUPPORT_TRAIN) | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc | ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc | ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_model.cc | ${CMAKE_CURRENT_SOURCE_DIR}/train/train_model.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/loss_monitor.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/lr_scheduler.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| @@ -19,6 +19,9 @@ | |||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| #include "src/ops/ops_register.h" | #include "src/ops/ops_register.h" | ||||
| #endif | #endif | ||||
| #ifdef SUPPORT_TRAIN | |||||
| #include <tuple> | |||||
| #endif | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -53,12 +56,20 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs | |||||
| } | } | ||||
| string paddingmode = "REFLECT"; | string paddingmode = "REFLECT"; | ||||
| if (prim.GetAttr("mode") == nullptr) { | if (prim.GetAttr("mode") == nullptr) { | ||||
| MS_LOG(ERROR) << "get mode failed!"; | |||||
| delete this->primitive_; | |||||
| delete attr; | |||||
| this->primitive_ = nullptr; | |||||
| attr = nullptr; | |||||
| return RET_ERROR; | |||||
| #ifdef SUPPORT_TRAIN | |||||
| if (prim.name() == "Pad") { | |||||
| paddingmode = "CONSTANT"; | |||||
| } else { | |||||
| #endif | |||||
| MS_LOG(ERROR) << "get mode failed!"; | |||||
| delete this->primitive_; | |||||
| delete attr; | |||||
| this->primitive_ = nullptr; | |||||
| attr = nullptr; | |||||
| return RET_ERROR; | |||||
| #ifdef SUPPORT_TRAIN | |||||
| } | |||||
| #endif | |||||
| } else { | } else { | ||||
| paddingmode = GetValue<string>(prim.GetAttr("mode")); | paddingmode = GetValue<string>(prim.GetAttr("mode")); | ||||
| } | } | ||||
| @@ -66,6 +77,21 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs | |||||
| attr->paddingMode = schema::PaddingMode_REFLECT; | attr->paddingMode = schema::PaddingMode_REFLECT; | ||||
| } else if (paddingmode == "SYMMETRIC") { | } else if (paddingmode == "SYMMETRIC") { | ||||
| attr->paddingMode = schema::PaddingMode_SYMMETRIC; | attr->paddingMode = schema::PaddingMode_SYMMETRIC; | ||||
| #ifdef SUPPORT_TRAIN | |||||
| } else if (paddingmode == "CONSTANT") { | |||||
| attr->paddingMode = schema::PaddingMode_CONSTANT; | |||||
| if (prim.GetAttr("paddings") != nullptr) { | |||||
| auto paddings = prim.GetAttr("paddings"); | |||||
| auto str = (*paddings).ToString(); | |||||
| std::replace(str.begin(), str.end(), ',', ' '); | |||||
| std::replace(str.begin(), str.end(), ')', ' '); | |||||
| std::replace(str.begin(), str.end(), '(', ' '); | |||||
| std::stringstream ss(str); | |||||
| for (int i; ss >> i;) { | |||||
| attr->paddings.push_back(i); | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "model type not supported!"; | MS_LOG(ERROR) << "model type not supported!"; | ||||
| delete this->primitive_; | delete this->primitive_; | ||||
| @@ -674,7 +674,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || | } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || | ||||
| op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { | op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { | ||||
| return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | ||||
| } else if ((op_type == "MaxPoolGrad") || (op_type == "AvgPoolGrad") || (op_type == "AvgPoolGradGpu")) { | |||||
| } else if ((op_type == "MaxPoolGrad") || (op_type == "AvgPoolGrad") || (op_type == "AvgPoolGradGpu") || | |||||
| (op_type == "AvgPoolGradCpu")) { | |||||
| return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType); | return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType); | ||||
| } else if (op_type == "Conv2DBackpropFilter") { | } else if (op_type == "Conv2DBackpropFilter") { | ||||
| return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType); | return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType); | ||||
| @@ -684,7 +685,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | ||||
| } else if (op_type == "FlattenGrad") { | } else if (op_type == "FlattenGrad") { | ||||
| return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType); | return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType); | ||||
| } else if (op_type == "FusedBatchNormGrad") { | |||||
| } else if ((op_type == "FusedBatchNormGrad") || (op_type == "FusedBatchNormGradCpu")) { | |||||
| return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | ||||
| } else if (op_type == "PowerGrad") { | } else if (op_type == "PowerGrad") { | ||||
| return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | ||||
| @@ -714,6 +715,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<SigmoidCrossEntropyWithLogits>(prim, inputs, quantType); | return NewPrimitiveC<SigmoidCrossEntropyWithLogits>(prim, inputs, quantType); | ||||
| } else if (op_type == "SigmoidCrossEntropyWithLogitsGrad") { | } else if (op_type == "SigmoidCrossEntropyWithLogitsGrad") { | ||||
| return NewPrimitiveC<SigmoidCrossEntropyWithLogitsGrad>(prim, inputs, quantType); | return NewPrimitiveC<SigmoidCrossEntropyWithLogitsGrad>(prim, inputs, quantType); | ||||
| } else if (op_type == "Pad") { | |||||
| return NewPrimitiveC<Pad>(prim, inputs, quantType); | |||||
| #else | #else | ||||
| } else if (op_type == "Conv2DBackpropInput") { | } else if (op_type == "Conv2DBackpropInput") { | ||||
| return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | ||||
| @@ -237,6 +237,9 @@ int Convolution1x1CPUKernel::Run() { | |||||
| MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; | MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; | ||||
| return RET_MEMORY_FAILED; | return RET_MEMORY_FAILED; | ||||
| } | } | ||||
| if (IsTrain()) { | |||||
| PackWeight(); | |||||
| } | |||||
| for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { | for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { | ||||
| output_ptr_ = src_out + batch_index * matmul_param_->row_ * matmul_param_->col_; | output_ptr_ = src_out + batch_index * matmul_param_->row_ * matmul_param_->col_; | ||||
| @@ -261,4 +264,45 @@ int Convolution1x1CPUKernel::Run() { | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void Convolution1x1CPUKernel::PackWeight() { | |||||
| auto filter_tensor = in_tensors_.at(kWeightIndex); | |||||
| auto input_channel = filter_tensor->Channel(); | |||||
| auto output_channel = filter_tensor->Batch(); | |||||
| #ifdef ENABLE_AVX | |||||
| row_tile_ = C6NUM; | |||||
| col_tile_ = C16NUM; | |||||
| #elif defined(ENABLE_SSE) | |||||
| row_tile_ = C4NUM; | |||||
| col_tile_ = C8NUM; | |||||
| #elif defined(ENABLE_ARM32) | |||||
| row_tile_ = C12NUM; | |||||
| col_tile_ = C4NUM; | |||||
| #else | |||||
| row_tile_ = C12NUM; | |||||
| col_tile_ = C8NUM; | |||||
| #endif | |||||
| int size = input_channel * UP_ROUND(output_channel, col_tile_) * sizeof(float); | |||||
| int down_size = input_channel * DOWN_DIV(output_channel, col_tile_) * col_tile_ * sizeof(float); | |||||
| memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size); | |||||
| #ifdef ENABLE_AVX | |||||
| RowMajor2Col16Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel, | |||||
| input_channel); | |||||
| #elif defined(ENABLE_ARM32) | |||||
| RowMajor2Col4Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel, | |||||
| input_channel); | |||||
| #else | |||||
| RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel, | |||||
| input_channel); | |||||
| #endif | |||||
| } | |||||
| int Convolution1x1CPUKernel::Eval() { | |||||
| LiteKernel::Eval(); | |||||
| PackWeight(); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_FP32_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_FP32_H_ | |||||
| #include <float.h> | #include <float.h> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -42,6 +42,7 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| int Init() override; | int Init() override; | ||||
| int Run() override; | int Run() override; | ||||
| int ReSize() override; | int ReSize() override; | ||||
| int Eval() override; | |||||
| public: | public: | ||||
| int DoConv1x1(int task_id); | int DoConv1x1(int task_id); | ||||
| @@ -53,6 +54,7 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| void InitConv1x1MatmulParam(); | void InitConv1x1MatmulParam(); | ||||
| void FreeTmpBuffer(); | void FreeTmpBuffer(); | ||||
| void PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col); | void PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col); | ||||
| void PackWeight(); | |||||
| private: | private: | ||||
| MatMulParameter *matmul_param_ = nullptr; | MatMulParameter *matmul_param_ = nullptr; | ||||
| @@ -70,4 +72,4 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| int col_tile_ = 0; | int col_tile_ = 0; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_FP32_H_ | |||||
| @@ -47,6 +47,15 @@ class ConvolutionDelegateCPUKernel : public LiteKernel { | |||||
| static float *CopyData(lite::Tensor *tensor); | static float *CopyData(lite::Tensor *tensor); | ||||
| void FreeCopiedData(); | void FreeCopiedData(); | ||||
| int Eval() override { | |||||
| LiteKernel::Eval(); | |||||
| return conv_kernel_->Eval(); | |||||
| } | |||||
| int Train() override { | |||||
| LiteKernel::Train(); | |||||
| return conv_kernel_->Train(); | |||||
| } | |||||
| protected: | protected: | ||||
| bool need_free_weight_ = false; | bool need_free_weight_ = false; | ||||
| bool need_free_bias_ = false; | bool need_free_bias_ = false; | ||||
| @@ -127,6 +127,10 @@ int ConvolutionDepthwise3x3CPUKernel::Run() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (IsTrain()) { | |||||
| PackWeight(); | |||||
| } | |||||
| auto input_tensor = in_tensors_.at(kInputIndex); | auto input_tensor = in_tensors_.at(kInputIndex); | ||||
| input_ptr_ = reinterpret_cast<float *>(input_tensor->data_c()); | input_ptr_ = reinterpret_cast<float *>(input_tensor->data_c()); | ||||
| @@ -146,4 +150,18 @@ int ConvolutionDepthwise3x3CPUKernel::Run() { | |||||
| context_->allocator->Free(buffer_); | context_->allocator->Free(buffer_); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void ConvolutionDepthwise3x3CPUKernel::PackWeight() { | |||||
| auto weight_tensor = in_tensors_.at(kWeightIndex); | |||||
| auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData()); | |||||
| PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), | |||||
| weight_tensor->Batch()); | |||||
| } | |||||
| int ConvolutionDepthwise3x3CPUKernel::Eval() { | |||||
| LiteKernel::Eval(); | |||||
| PackWeight(); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -37,8 +37,10 @@ class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| int InitWeightBias(); | int InitWeightBias(); | ||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| int Eval() override; | |||||
| private: | private: | ||||
| void PackWeight(); | |||||
| int InitBuffer(); | int InitBuffer(); | ||||
| SlidingWindowParam *sliding_ = nullptr; | SlidingWindowParam *sliding_ = nullptr; | ||||
| float *packed_weight_ = nullptr; | float *packed_weight_ = nullptr; | ||||
| @@ -48,4 +50,4 @@ class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_ | |||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h" | #include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h" | ||||
| #include <limits> | |||||
| #include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h" | #include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h" | ||||
| #include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h" | #include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h" | ||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| @@ -104,6 +105,10 @@ int ConvDwRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int ConvolutionDepthwiseCPUKernel::Run() { | int ConvolutionDepthwiseCPUKernel::Run() { | ||||
| if (IsTrain()) { | |||||
| PackWeight(); | |||||
| } | |||||
| auto input_tensor = in_tensors_.at(kInputIndex); | auto input_tensor = in_tensors_.at(kInputIndex); | ||||
| input_ptr_ = reinterpret_cast<float *>(input_tensor->MutableData()); | input_ptr_ = reinterpret_cast<float *>(input_tensor->MutableData()); | ||||
| @@ -118,6 +123,19 @@ int ConvolutionDepthwiseCPUKernel::Run() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void ConvolutionDepthwiseCPUKernel::PackWeight() { | |||||
| auto weight_tensor = in_tensors_.at(kWeightIndex); | |||||
| auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData()); | |||||
| PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), | |||||
| weight_tensor->Batch()); | |||||
| } | |||||
| int ConvolutionDepthwiseCPUKernel::Eval() { | |||||
| LiteKernel::Eval(); | |||||
| PackWeight(); | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | ||||
| const InnerContext *ctx, const kernel::KernelKey &desc, | const InnerContext *ctx, const kernel::KernelKey &desc, | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_FP32_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_FP32_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -37,12 +37,14 @@ class ConvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| int InitWeightBias(); | int InitWeightBias(); | ||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| int Eval() override; | |||||
| private: | private: | ||||
| void PackWeight(); | |||||
| float *packed_weight_ = nullptr; | float *packed_weight_ = nullptr; | ||||
| float *input_ptr_ = nullptr; | float *input_ptr_ = nullptr; | ||||
| float *output_ptr_ = nullptr; | float *output_ptr_ = nullptr; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_FP32_H_ | |||||
| @@ -190,6 +190,10 @@ int ConvolutionDepthwiseIndirectCPUKernel::Run() { | |||||
| packed_input_ = input_ptr; | packed_input_ = input_ptr; | ||||
| } | } | ||||
| if (IsTrain()) { | |||||
| PackWeight(); | |||||
| } | |||||
| auto output_tensor = out_tensors_.at(kOutputIndex); | auto output_tensor = out_tensors_.at(kOutputIndex); | ||||
| output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c()); | output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c()); | ||||
| @@ -205,4 +209,23 @@ int ConvolutionDepthwiseIndirectCPUKernel::Run() { | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void ConvolutionDepthwiseIndirectCPUKernel::PackWeight() { | |||||
| auto weight_tensor = in_tensors_[kWeightIndex]; | |||||
| auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData()); | |||||
| #ifdef ENABLE_AVX | |||||
| PackDepthwiseIndirectWeightC8Fp32(origin_weight, packed_weight_, weight_tensor->Height(), weight_tensor->Width(), | |||||
| weight_tensor->Batch()); | |||||
| #else | |||||
| PackDepthwiseIndirectWeightC4Fp32(origin_weight, packed_weight_, weight_tensor->Height(), weight_tensor->Width(), | |||||
| weight_tensor->Batch()); | |||||
| #endif | |||||
| } | |||||
| int ConvolutionDepthwiseIndirectCPUKernel::Eval() { | |||||
| LiteKernel::Eval(); | |||||
| PackWeight(); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_FP32_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_FP32_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -37,10 +37,12 @@ class ConvolutionDepthwiseIndirectCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| int InitWeightBias(); | int InitWeightBias(); | ||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| int Eval() override; | |||||
| private: | private: | ||||
| int MallocIndirectBuffer(); | int MallocIndirectBuffer(); | ||||
| int MallocPackedInput(); | int MallocPackedInput(); | ||||
| void PackWeight(); | |||||
| int step_w = 0; | int step_w = 0; | ||||
| int step_h = 0; | int step_h = 0; | ||||
| float **indirect_buffer_ = nullptr; | float **indirect_buffer_ = nullptr; | ||||
| @@ -51,4 +53,4 @@ class ConvolutionDepthwiseIndirectCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_FP32_H_ | |||||
| @@ -145,6 +145,11 @@ int ConvolutionDepthwiseSWCPUKernel::Run() { | |||||
| FreePackedInputOutput(); | FreePackedInputOutput(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (IsTrain()) { | |||||
| PackWeight(); | |||||
| } | |||||
| auto input_tensor = in_tensors_.at(kInputIndex); | auto input_tensor = in_tensors_.at(kInputIndex); | ||||
| auto input_ptr = reinterpret_cast<float *>(input_tensor->MutableData()); | auto input_ptr = reinterpret_cast<float *>(input_tensor->MutableData()); | ||||
| @@ -183,4 +188,18 @@ void ConvolutionDepthwiseSWCPUKernel::FreePackedInputOutput() { | |||||
| packed_output_ = nullptr; | packed_output_ = nullptr; | ||||
| } | } | ||||
| } | } | ||||
| void ConvolutionDepthwiseSWCPUKernel::PackWeight() { | |||||
| auto weight_tensor = in_tensors_.at(kWeightIndex); | |||||
| auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData()); | |||||
| PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, weight_tensor->Height() * weight_tensor->Width(), | |||||
| weight_tensor->Batch()); | |||||
| } | |||||
| int ConvolutionDepthwiseSWCPUKernel::Eval() { | |||||
| LiteKernel::Eval(); | |||||
| PackWeight(); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_FP32_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_FP32_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -37,10 +37,12 @@ class ConvolutionDepthwiseSWCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| int InitWeightBias(); | int InitWeightBias(); | ||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| int Eval() override; | |||||
| private: | private: | ||||
| int InitPackedInputOutput(); | int InitPackedInputOutput(); | ||||
| void FreePackedInputOutput(); | void FreePackedInputOutput(); | ||||
| void PackWeight(); | |||||
| SlidingWindowParam *sliding_ = nullptr; | SlidingWindowParam *sliding_ = nullptr; | ||||
| float *packed_weight_ = nullptr; | float *packed_weight_ = nullptr; | ||||
| float *packed_input_ = nullptr; | float *packed_input_ = nullptr; | ||||
| @@ -49,4 +51,4 @@ class ConvolutionDepthwiseSWCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_FP32_H_ | |||||
| @@ -150,6 +150,9 @@ int ConvolutionCPUKernel::Run() { | |||||
| FreeTmpBuffer(); | FreeTmpBuffer(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (IsTrain()) { | |||||
| PackWeight(); | |||||
| } | |||||
| ret = ParallelLaunch(this->context_->thread_pool_, ConvolutionImpl, this, thread_count_); | ret = ParallelLaunch(this->context_->thread_pool_, ConvolutionImpl, this, thread_count_); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -158,4 +161,37 @@ int ConvolutionCPUKernel::Run() { | |||||
| FreeTmpBuffer(); | FreeTmpBuffer(); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| void ConvolutionCPUKernel::PackWeight() { | |||||
| auto filter_tensor = in_tensors_.at(kWeightIndex); | |||||
| int in_channel = filter_tensor->Channel(); | |||||
| int out_channel = filter_tensor->Batch(); | |||||
| int kernel_plane = filter_tensor->Height() * filter_tensor->Width(); | |||||
| #ifdef ENABLE_AVX | |||||
| const int oc_block = C16NUM; | |||||
| #elif ENABLE_ARM32 | |||||
| const int oc_block = C4NUM; | |||||
| #else | |||||
| const int oc_block = C8NUM; | |||||
| #endif | |||||
| int oc_block_num = UP_ROUND(out_channel, oc_block); | |||||
| int pack_weight_size = oc_block_num * in_channel * kernel_plane; | |||||
| auto origin_weight = reinterpret_cast<float *>(filter_tensor->data_c()); | |||||
| memset(packed_weight_, 0, pack_weight_size * sizeof(float)); | |||||
| #ifdef ENABLE_AVX | |||||
| RowMajor2Col16Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); | |||||
| #elif ENABLE_ARM32 | |||||
| RowMajor2Col4Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); | |||||
| #else | |||||
| RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); | |||||
| #endif | |||||
| } | |||||
| int ConvolutionCPUKernel::Eval() { | |||||
| LiteKernel::Eval(); | |||||
| PackWeight(); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -46,7 +46,10 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| int Run() override; | int Run() override; | ||||
| virtual int RunImpl(int task_id); | virtual int RunImpl(int task_id); | ||||
| int Eval() override; | |||||
| protected: | protected: | ||||
| void PackWeight(); | |||||
| void FreeTmpBuffer() { | void FreeTmpBuffer() { | ||||
| if (packed_input_ != nullptr) { | if (packed_input_ != nullptr) { | ||||
| ctx_->allocator->Free(packed_input_); | ctx_->allocator->Free(packed_input_); | ||||
| @@ -58,10 +58,12 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { | |||||
| // set data | // set data | ||||
| auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float); | auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float); | ||||
| trans_weight_ = reinterpret_cast<float *>(malloc(trans_matrix_data_size)); | |||||
| if (trans_weight_ == nullptr) { | if (trans_weight_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc matrix_buffer failed."; | |||||
| return RET_MEMORY_FAILED; | |||||
| trans_weight_ = reinterpret_cast<float *>(malloc(trans_matrix_data_size)); | |||||
| if (trans_weight_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc matrix_buffer failed."; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| } | } | ||||
| memset(trans_weight_, 0, trans_matrix_data_size); | memset(trans_weight_, 0, trans_matrix_data_size); | ||||
| @@ -217,6 +219,9 @@ int ConvolutionWinogradCPUKernel::Run() { | |||||
| FreeTmpBuffer(); | FreeTmpBuffer(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (IsTrain()) { | |||||
| InitWeightBias(); | |||||
| } | |||||
| ret = ParallelLaunch(this->context_->thread_pool_, ConvolutionWinogradImpl, this, thread_count_); | ret = ParallelLaunch(this->context_->thread_pool_, ConvolutionWinogradImpl, this, thread_count_); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -226,4 +231,11 @@ int ConvolutionWinogradCPUKernel::Run() { | |||||
| FreeTmpBuffer(); | FreeTmpBuffer(); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| int ConvolutionWinogradCPUKernel::Eval() { | |||||
| LiteKernel::Eval(); | |||||
| InitWeightBias(); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_FP32_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_FP32_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| @@ -43,6 +43,7 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| int Eval() override; | |||||
| int RunImpl(int task_id); | int RunImpl(int task_id); | ||||
| int InitWeightBias(); | int InitWeightBias(); | ||||
| int InitTmpBuffer(); | int InitTmpBuffer(); | ||||
| @@ -84,4 +85,4 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_FP32_H_ | |||||
| @@ -48,33 +48,27 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { | |||||
| auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| int length = in_tensors_.at(0)->ElementsNum(); | int length = in_tensors_.at(0)->ElementsNum(); | ||||
| int stride = UP_DIV(length, 1); | |||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | int count = MSMIN(stride, length - stride * task_id); | ||||
| size_t start = stride * task_id; | |||||
| auto error_code = RET_OK; | auto error_code = RET_OK; | ||||
| if (param_act_grad_->type_ == schema::ActivationType_RELU) { | if (param_act_grad_->type_ == schema::ActivationType_RELU) { | ||||
| error_code = | |||||
| ReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| error_code = ReluGrad(yt_addr + start, input_addr + start, count, output_addr + start); | |||||
| } else if (param_act_grad_->type_ == schema::ActivationType_RELU6) { | } else if (param_act_grad_->type_ == schema::ActivationType_RELU6) { | ||||
| error_code = | |||||
| Relu6Grad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| error_code = Relu6Grad(yt_addr + start, input_addr + start, count, output_addr + start); | |||||
| } else if (param_act_grad_->type_ == schema::ActivationType_LEAKY_RELU) { | } else if (param_act_grad_->type_ == schema::ActivationType_LEAKY_RELU) { | ||||
| error_code = LReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, | |||||
| output_addr + stride * task_id, param_act_grad_->alpha_); | |||||
| error_code = LReluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_); | |||||
| } else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) { | } else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) { | ||||
| // Sigmoid gets the input tensors in reverse order! | // Sigmoid gets the input tensors in reverse order! | ||||
| error_code = | |||||
| SigmoidGrad(input_addr + stride * task_id, yt_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| error_code = SigmoidGrad(input_addr + start, yt_addr + start, count, output_addr + start); | |||||
| } else if (param_act_grad_->type_ == schema::ActivationType_TANH) { | } else if (param_act_grad_->type_ == schema::ActivationType_TANH) { | ||||
| error_code = | |||||
| TanhGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| error_code = TanhGrad(yt_addr + start, input_addr + start, count, output_addr + start); | |||||
| } else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) { | } else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) { | ||||
| error_code = | |||||
| HSwishGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| error_code = HSwishGrad(yt_addr + start, input_addr + start, count, output_addr + start); | |||||
| } else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) { | } else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) { | ||||
| error_code = | |||||
| HSigmoidGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| error_code = HSigmoidGrad(yt_addr + start, input_addr + start, count, output_addr + start); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Activation type error"; | MS_LOG(ERROR) << "Activation type error"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -97,7 +91,7 @@ int ActivationGradRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int ActivationGradCPUKernel::Run() { | int ActivationGradCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ActivationGradRun, this, 1); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ActivationGradRun, this, thread_count_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "Activation Grad function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "Activation Grad function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -27,7 +27,7 @@ class ActivationGradCPUKernel : public LiteKernel { | |||||
| explicit ActivationGradCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs, | explicit ActivationGradCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(param, inputs, outputs, ctx, primitive) { | |||||
| : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { | |||||
| param_act_grad_ = reinterpret_cast<ActivationParameter *>(param); | param_act_grad_ = reinterpret_cast<ActivationParameter *>(param); | ||||
| } | } | ||||
| ~ActivationGradCPUKernel() override = default; | ~ActivationGradCPUKernel() override = default; | ||||
| @@ -39,6 +39,7 @@ class ActivationGradCPUKernel : public LiteKernel { | |||||
| private: | private: | ||||
| ActivationParameter *param_act_grad_; | ActivationParameter *param_act_grad_; | ||||
| int thread_count_; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -33,17 +33,23 @@ namespace mindspore::kernel { | |||||
| int AdamCPUKernel::ReSize() { return RET_OK; } | int AdamCPUKernel::ReSize() { return RET_OK; } | ||||
| int AdamCPUKernel::Execute(int task_id) { | int AdamCPUKernel::Execute(int task_id) { | ||||
| auto weight = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||||
| auto m = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||||
| auto v = reinterpret_cast<float *>(in_tensors_[2]->MutableData()); | |||||
| auto beta1_power = reinterpret_cast<float *>(in_tensors_[3]->MutableData())[0]; | |||||
| auto beta2_power = reinterpret_cast<float *>(in_tensors_[4]->MutableData())[0]; | |||||
| auto learning_rate = reinterpret_cast<float *>(in_tensors_[5]->MutableData())[0]; | |||||
| auto beta1 = reinterpret_cast<float *>(in_tensors_[6]->MutableData())[0]; | |||||
| auto beta2 = reinterpret_cast<float *>(in_tensors_[7]->MutableData())[0]; | |||||
| auto eps = reinterpret_cast<float *>(in_tensors_[8]->MutableData())[0]; | |||||
| auto gradient = reinterpret_cast<float *>(in_tensors_[9]->MutableData()); | |||||
| size_t elem_num = in_tensors_[0]->ElementsNum(); | |||||
| auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||||
| auto m = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | |||||
| auto v = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||||
| auto beta1_power = reinterpret_cast<float *>(in_tensors_.at(3)->MutableData())[0]; | |||||
| auto beta2_power = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0]; | |||||
| auto learning_rate = reinterpret_cast<float *>(in_tensors_.at(5)->MutableData())[0]; | |||||
| auto beta1 = reinterpret_cast<float *>(in_tensors_.at(6)->MutableData())[0]; | |||||
| auto beta2 = reinterpret_cast<float *>(in_tensors_.at(7)->MutableData())[0]; | |||||
| auto eps = reinterpret_cast<float *>(in_tensors_.at(8)->MutableData())[0]; | |||||
| auto gradient = reinterpret_cast<float *>(in_tensors_.at(9)->MutableData()); | |||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| if ((1.f - beta1_power) <= 0.0f) { | if ((1.f - beta1_power) <= 0.0f) { | ||||
| MS_LOG(ERROR) << "divisor cannot be 0 or below"; | MS_LOG(ERROR) << "divisor cannot be 0 or below"; | ||||
| @@ -55,17 +61,19 @@ int AdamCPUKernel::Execute(int task_id) { | |||||
| } | } | ||||
| auto update_lr = learning_rate * std::sqrt(1.f - beta2_power) / (1.f - beta1_power); | auto update_lr = learning_rate * std::sqrt(1.f - beta2_power) / (1.f - beta1_power); | ||||
| const float one_minus_beta1 = 1.f - beta1; | |||||
| const float one_minus_beta2 = 1.f - beta2; | |||||
| if (adam_param_->use_nesterov_) { // Nadam | if (adam_param_->use_nesterov_) { // Nadam | ||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| m[i] += (gradient[i] - m[i]) * (1.f - beta1); | |||||
| v[i] += (gradient[i] * gradient[i] - v[i]) * (1.f - beta2); | |||||
| weight[i] -= update_lr * (m[i] * beta1 + (1.f - beta1) * gradient[i]) / (std::sqrt(v[i]) + eps); | |||||
| for (size_t i = start; i < end; ++i) { | |||||
| m[i] += (gradient[i] - m[i]) * one_minus_beta1; | |||||
| v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2; | |||||
| weight[i] -= update_lr * (m[i] * beta1 + one_minus_beta1 * gradient[i]) / (std::sqrt(v[i]) + eps); | |||||
| } | } | ||||
| } else { | } else { | ||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| m[i] += (gradient[i] - m[i]) * (1.f - beta1); | |||||
| v[i] += (gradient[i] * gradient[i] - v[i]) * (1.f - beta2); | |||||
| for (size_t i = start; i < end; ++i) { | |||||
| m[i] += (gradient[i] - m[i]) * one_minus_beta1; | |||||
| v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2; | |||||
| weight[i] -= update_lr * m[i] / (std::sqrt(v[i]) + eps); | weight[i] -= update_lr * m[i] / (std::sqrt(v[i]) + eps); | ||||
| } | } | ||||
| } | } | ||||
| @@ -84,7 +92,7 @@ int AdamRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int AdamCPUKernel::Run() { | int AdamCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, 1); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, thread_count_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "Adam function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "Adam function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -92,6 +100,17 @@ int AdamCPUKernel::Run() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int AdamCPUKernel::SetLearningRate(float lr) { | |||||
| auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(5)->MutableData()); | |||||
| learning_rate_tensor[0] = lr; | |||||
| return RET_OK; | |||||
| } | |||||
| float AdamCPUKernel::GetLearningRate() { | |||||
| auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(5)->MutableData()); | |||||
| return learning_rate_tensor[0]; | |||||
| } | |||||
| int AdamCPUKernel::Init() { return RET_OK; } | int AdamCPUKernel::Init() { return RET_OK; } | ||||
| kernel::LiteKernel *CpuAdamFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | kernel::LiteKernel *CpuAdamFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | ||||
| @@ -18,25 +18,28 @@ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ | #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | |||||
| #include "src/train/optimizer_kernel.h" | |||||
| #include "nnacl/fp32_grad/optimizer.h" | #include "nnacl/fp32_grad/optimizer.h" | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| class AdamCPUKernel : public LiteKernel { | |||||
| class AdamCPUKernel : public OptimizerKernel { | |||||
| public: | public: | ||||
| explicit AdamCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | explicit AdamCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| : OptimizerKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { | |||||
| adam_param_ = reinterpret_cast<AdamParameter *>(parameter); | adam_param_ = reinterpret_cast<AdamParameter *>(parameter); | ||||
| } | } | ||||
| ~AdamCPUKernel() override {} | ~AdamCPUKernel() override {} | ||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| int SetLearningRate(float lr) override; | |||||
| float GetLearningRate() override; | |||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| private: | private: | ||||
| int thread_count_; | |||||
| AdamParameter *adam_param_; | AdamParameter *adam_param_; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -31,20 +31,26 @@ namespace mindspore::kernel { | |||||
| int ApplyMomentumCPUKernel::ReSize() { return RET_OK; } | int ApplyMomentumCPUKernel::ReSize() { return RET_OK; } | ||||
| int ApplyMomentumCPUKernel::Execute(int task_id) { | int ApplyMomentumCPUKernel::Execute(int task_id) { | ||||
| auto weight = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||||
| auto accumulate = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||||
| float learning_rate = reinterpret_cast<float *>(in_tensors_[2]->MutableData())[0]; | |||||
| auto gradient = reinterpret_cast<float *>(in_tensors_[3]->MutableData()); | |||||
| float moment = reinterpret_cast<float *>(in_tensors_[4]->MutableData())[0]; | |||||
| size_t elem_num = in_tensors_[0]->ElementsNum(); | |||||
| auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||||
| auto accumulate = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | |||||
| float learning_rate = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData())[0]; | |||||
| auto gradient = reinterpret_cast<float *>(in_tensors_.at(3)->MutableData()); | |||||
| float moment = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0]; | |||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| if (apply_momentum_param_->use_nesterov_) { | if (apply_momentum_param_->use_nesterov_) { | ||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| for (size_t i = start; i < end; ++i) { | |||||
| accumulate[i] = accumulate[i] * moment + gradient[i]; | accumulate[i] = accumulate[i] * moment + gradient[i]; | ||||
| weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; | weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; | ||||
| } | } | ||||
| } else { | } else { | ||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| for (size_t i = start; i < end; ++i) { | |||||
| accumulate[i] = accumulate[i] * moment + gradient[i]; | accumulate[i] = accumulate[i] * moment + gradient[i]; | ||||
| weight[i] -= accumulate[i] * learning_rate; | weight[i] -= accumulate[i] * learning_rate; | ||||
| } | } | ||||
| @@ -64,7 +70,7 @@ int ApplyMomentumRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int ApplyMomentumCPUKernel::Run() { | int ApplyMomentumCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ApplyMomentumRun, this, 1); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ApplyMomentumRun, this, thread_count_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "Apply Momentum function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "Apply Momentum function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -74,6 +80,17 @@ int ApplyMomentumCPUKernel::Run() { | |||||
| int ApplyMomentumCPUKernel::Init() { return RET_OK; } | int ApplyMomentumCPUKernel::Init() { return RET_OK; } | ||||
| int ApplyMomentumCPUKernel::SetLearningRate(float lr) { | |||||
| auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||||
| learning_rate_tensor[0] = lr; | |||||
| return RET_OK; | |||||
| } | |||||
| float ApplyMomentumCPUKernel::GetLearningRate() { | |||||
| auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||||
| return learning_rate_tensor[0]; | |||||
| } | |||||
| kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, | const std::vector<lite::Tensor *> &outputs, | ||||
| OpParameter *opParameter, const lite::InnerContext *ctx, | OpParameter *opParameter, const lite::InnerContext *ctx, | ||||
| @@ -18,16 +18,18 @@ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_APPLY_MOMENTUM_H_ | #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_APPLY_MOMENTUM_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | |||||
| #include "src/train/optimizer_kernel.h" | |||||
| #include "nnacl/fp32_grad/optimizer.h" | #include "nnacl/fp32_grad/optimizer.h" | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| class ApplyMomentumCPUKernel : public LiteKernel { | |||||
| class ApplyMomentumCPUKernel : public OptimizerKernel { | |||||
| public: | public: | ||||
| explicit ApplyMomentumCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | explicit ApplyMomentumCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), apply_momentum_param_(nullptr) { | |||||
| : OptimizerKernel(parameter, inputs, outputs, ctx, primitive), | |||||
| thread_count_(ctx->thread_num_), | |||||
| apply_momentum_param_(nullptr) { | |||||
| apply_momentum_param_ = reinterpret_cast<ApplyMomentumParameter *>(parameter); | apply_momentum_param_ = reinterpret_cast<ApplyMomentumParameter *>(parameter); | ||||
| } | } | ||||
| ~ApplyMomentumCPUKernel() override {} | ~ApplyMomentumCPUKernel() override {} | ||||
| @@ -35,8 +37,11 @@ class ApplyMomentumCPUKernel : public LiteKernel { | |||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| int SetLearningRate(float lr) override; | |||||
| float GetLearningRate() override; | |||||
| private: | private: | ||||
| int thread_count_; | |||||
| ApplyMomentumParameter *apply_momentum_param_; | ApplyMomentumParameter *apply_momentum_param_; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -49,27 +49,24 @@ int ArithmeticSelfGradCPUKernel::Init() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ArithmeticSelfGradCPUKernel::DoArithmeticSelfGrad(int thread_id) { | |||||
| auto dy = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||||
| auto in_x = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||||
| auto dx = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | |||||
| int dy_size = in_tensors_.at(0)->ElementsNum(); | |||||
| int size = MSMIN(thread_stride_, static_cast<int>(dy_size - thread_id * thread_stride_)); | |||||
| if (size <= 0) { | |||||
| return RET_OK; | |||||
| } | |||||
| int offset = thread_id * thread_stride_; | |||||
| (*self_grad_operation_)(dy + offset, in_x + offset, dx + offset, size); | |||||
| int ArithmeticSelfGradCPUKernel::DoArithmeticSelfGrad(int task_id) { | |||||
| auto dy = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||||
| auto in_x = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | |||||
| auto dx = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| (*self_grad_operation_)(dy + start, in_x + start, dx + start, count); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ArithmeticSelfGradCPUKernel::ReSize() { return RET_OK; } | int ArithmeticSelfGradCPUKernel::ReSize() { return RET_OK; } | ||||
| int ArithmeticSelfGradCPUKernel::Run() { | int ArithmeticSelfGradCPUKernel::Run() { | ||||
| int dy_size = in_tensors_.at(0)->ElementsNum(); | |||||
| op_parameter_->thread_num_ = MSMIN(op_parameter_->thread_num_, static_cast<int>(dy_size)); | |||||
| thread_stride_ = UP_DIV(dy_size, op_parameter_->thread_num_); | |||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticSelfGradRun, this, op_parameter_->thread_num_); | |||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticSelfGradRun, this, thread_count_); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "parallel launch fail!ret: " << ret; | MS_LOG(ERROR) << "parallel launch fail!ret: " << ret; | ||||
| return ret; | return ret; | ||||
| @@ -30,7 +30,7 @@ class ArithmeticSelfGradCPUKernel : public LiteKernel { | |||||
| ArithmeticSelfGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ArithmeticSelfGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} | |||||
| ~ArithmeticSelfGradCPUKernel() override {} | ~ArithmeticSelfGradCPUKernel() override {} | ||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override; | int ReSize() override; | ||||
| @@ -38,7 +38,7 @@ class ArithmeticSelfGradCPUKernel : public LiteKernel { | |||||
| int DoArithmeticSelfGrad(int thread_id); | int DoArithmeticSelfGrad(int thread_id); | ||||
| private: | private: | ||||
| int thread_stride_; | |||||
| int thread_count_; | |||||
| ArithmeticSelfGradOperation self_grad_operation_; | ArithmeticSelfGradOperation self_grad_operation_; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -32,11 +32,16 @@ namespace mindspore::kernel { | |||||
| int AssignCPUKernel::ReSize() { return RET_OK; } | int AssignCPUKernel::ReSize() { return RET_OK; } | ||||
| int AssignCPUKernel::Execute(int task_id) { | int AssignCPUKernel::Execute(int task_id) { | ||||
| auto x = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||||
| auto y = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||||
| size_t size = in_tensors_[0]->Size(); | |||||
| auto x = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||||
| auto y = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | |||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| memcpy(x, y, size); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| memcpy(&(x[start]), &(y[start]), count * sizeof(float)); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -52,7 +57,7 @@ int AssignRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int AssignCPUKernel::Run() { | int AssignCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, 1); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, thread_count_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "Assign function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "Assign function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -27,12 +27,15 @@ class AssignCPUKernel : public LiteKernel { | |||||
| explicit AssignCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | explicit AssignCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} | |||||
| ~AssignCPUKernel() override {} | ~AssignCPUKernel() override {} | ||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| protected: | |||||
| int thread_count_ = 1; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -29,7 +29,7 @@ using mindspore::schema::PrimitiveType_BiasGrad; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int BiasGradCPUKernel::Init() { | |||||
| int BiasGradCPUKernel::ReSize() { | |||||
| auto dims = in_tensors_[0]->shape(); | auto dims = in_tensors_[0]->shape(); | ||||
| bias_param->ndim_ = dims.size(); | bias_param->ndim_ = dims.size(); | ||||
| for (unsigned int i = 0; i < bias_param->ndim_; i++) { | for (unsigned int i = 0; i < bias_param->ndim_; i++) { | ||||
| @@ -44,7 +44,12 @@ int BiasGradCPUKernel::Init() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int BiasGradCPUKernel::ReSize() { return RET_OK; } | |||||
| int BiasGradCPUKernel::Init() { | |||||
| if (!InferShapeDone()) { | |||||
| return RET_OK; | |||||
| } | |||||
| return ReSize(); | |||||
| } | |||||
| int BiasGradCPUKernel::Execute(int task_id) { | int BiasGradCPUKernel::Execute(int task_id) { | ||||
| auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | ||||
| @@ -31,17 +31,16 @@ using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_BNGrad; | using mindspore::schema::PrimitiveType_BNGrad; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int BNGradCPUKernel::Init() { | |||||
| int BNGradCPUKernel::ReSize() { | |||||
| auto *input_x = in_tensors_.at(1); | auto *input_x = in_tensors_.at(1); | ||||
| int channels = input_x->shape().at(kNHWC_C); | int channels = input_x->shape().at(kNHWC_C); | ||||
| set_workspace_size(2 * channels * sizeof(float)); | set_workspace_size(2 * channels * sizeof(float)); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int BNGradCPUKernel::ReSize() { return RET_OK; } | |||||
| int BNGradCPUKernel::Init() { return ReSize(); } | |||||
| int BNGradCPUKernel::Execute(int task_id) { | int BNGradCPUKernel::Execute(int task_id) { | ||||
| auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_); | |||||
| auto *input_yt = in_tensors_.at(0); | auto *input_yt = in_tensors_.at(0); | ||||
| auto *input_x = in_tensors_.at(1); | auto *input_x = in_tensors_.at(1); | ||||
| auto *input_scale = in_tensors_.at(2); | auto *input_scale = in_tensors_.at(2); | ||||
| @@ -54,10 +53,9 @@ int BNGradCPUKernel::Execute(int task_id) { | |||||
| auto *output_dx = out_tensors_.at(0); | auto *output_dx = out_tensors_.at(0); | ||||
| auto *output_scale = out_tensors_.at(1); | auto *output_scale = out_tensors_.at(1); | ||||
| auto *output_bias = out_tensors_.at(2); | auto *output_bias = out_tensors_.at(2); | ||||
| size_t batch = input_x->Batch(); | |||||
| size_t channels = input_x->Channel(); | |||||
| size_t spatial = input_x->Height() * input_x->Width(); | |||||
| float eps = bn_param->epsilon_; | |||||
| int32_t batch = input_x->Batch(); | |||||
| int32_t channels = input_x->Channel(); | |||||
| int32_t spatial = input_x->Height() * input_x->Width(); | |||||
| float *workspace_temp = static_cast<float *>(workspace()); | float *workspace_temp = static_cast<float *>(workspace()); | ||||
| std::fill(workspace_temp, workspace_temp + workspace_size() / sizeof(*workspace_temp), 0.f); | std::fill(workspace_temp, workspace_temp + workspace_size() / sizeof(*workspace_temp), 0.f); | ||||
| @@ -68,34 +66,32 @@ int BNGradCPUKernel::Execute(int task_id) { | |||||
| float *yt = reinterpret_cast<float *>(input_yt->MutableData()); | float *yt = reinterpret_cast<float *>(input_yt->MutableData()); | ||||
| float *scale = reinterpret_cast<float *>(input_scale->MutableData()); | float *scale = reinterpret_cast<float *>(input_scale->MutableData()); | ||||
| float *dx = reinterpret_cast<float *>(output_dx->MutableData()); | float *dx = reinterpret_cast<float *>(output_dx->MutableData()); | ||||
| float *dscale = reinterpret_cast<float *>(output_scale->MutableData()); | |||||
| float *dbias = reinterpret_cast<float *>(output_bias->MutableData()); | float *dbias = reinterpret_cast<float *>(output_bias->MutableData()); | ||||
| var2Invar(save_var, input_var->ElementsNum(), eps); | |||||
| // dx | |||||
| backwardX(x, yt, scale, batch * spatial, channels, save_mean, save_var, dxhat_sum, dxhathat_sum, dx); | |||||
| // dbias | |||||
| sumSpatialBatch(yt, batch * spatial, channels, dbias); | |||||
| // dscale | |||||
| backwardScale(x, save_mean, save_var, yt, batch, channels, spatial, dscale); | |||||
| float *dscale = reinterpret_cast<float *>(output_scale->MutableData()); | |||||
| std::fill(dbias, dbias + channels, 0.f); | |||||
| std::fill(dscale, dscale + channels, 0.f); | |||||
| backwardAll(x, yt, save_mean, save_var, scale, batch * spatial, channels, dxhat_sum, dxhathat_sum, dbias, dscale, dx); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int BNGradRun(void *cdata, int task_id) { | int BNGradRun(void *cdata, int task_id) { | ||||
| MS_ASSERT(cdata != nullptr); | MS_ASSERT(cdata != nullptr); | ||||
| auto bn_kernel = reinterpret_cast<BNGradCPUKernel *>(cdata); | auto bn_kernel = reinterpret_cast<BNGradCPUKernel *>(cdata); | ||||
| if (task_id == 0) { | |||||
| auto error_code = bn_kernel->Execute(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto error_code = bn_kernel->Execute(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int BNGradCPUKernel::Run() { | int BNGradCPUKernel::Run() { | ||||
| auto *input_var = in_tensors_.at(4); | |||||
| float *save_var = reinterpret_cast<float *>(input_var->MutableData()); | |||||
| auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_); | |||||
| float eps = bn_param->epsilon_; | |||||
| var2Invar(save_var, input_var->ElementsNum(), eps); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1); | int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1); | ||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; | ||||
| @@ -26,7 +26,7 @@ using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int ConvolutionTrainCPUKernel::Init() { | |||||
| int ConvolutionTrainCPUKernel::ReSize() { | |||||
| if (in_tensors_.size() < 2) { | if (in_tensors_.size() < 2) { | ||||
| MS_LOG(ERROR) << "Convolution should have at least two inputs"; | MS_LOG(ERROR) << "Convolution should have at least two inputs"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -54,13 +54,21 @@ int ConvolutionTrainCPUKernel::Init() { | |||||
| conv_param_->group_ = (conv_param_->group_ == 0) ? conv_param_->input_channel_ : conv_param_->group_; | conv_param_->group_ = (conv_param_->group_ == 0) ? conv_param_->input_channel_ : conv_param_->group_; | ||||
| const int n = conv_param_->output_channel_ * conv_param_->group_; | const int n = conv_param_->output_channel_ * conv_param_->group_; | ||||
| const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_; | const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_; | ||||
| ws_size = chunk * k; | |||||
| int mat_alloc = MatSizeTotal(chunk, n, k, 0); | |||||
| set_workspace_size((ws_size + mat_alloc) * sizeof(float)); | |||||
| ws_size_ = chunk_ * k; | |||||
| int mat_alloc = MatSizeTotal(chunk_, n, k, 0); | |||||
| set_workspace_size((ws_size_ + mat_alloc) * sizeof(float)); | |||||
| do_img2col_ = (conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) && | |||||
| (conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) && | |||||
| (conv_param_->dilation_h_ == 1) && (conv_param_->dilation_w_ == 1) && | |||||
| (conv_param_->stride_h_ == 1) && (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1) | |||||
| ? false | |||||
| : true; | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ConvolutionTrainCPUKernel::ReSize() { return RET_OK; } | |||||
| int ConvolutionTrainCPUKernel::Init() { return ReSize(); } | |||||
| int ConvolutionTrainCPUKernel::Execute(int task_id) { | int ConvolutionTrainCPUKernel::Execute(int task_id) { | ||||
| auto conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter_); | auto conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter_); | ||||
| @@ -87,17 +95,34 @@ int ConvolutionTrainCPUKernel::Execute(int task_id) { | |||||
| const int n = out_ch / groups; | const int n = out_ch / groups; | ||||
| const int k = k_h * k_w * in_ch / groups; | const int k = k_h * k_w * in_ch / groups; | ||||
| float *workspace_temp = static_cast<float *>(workspace()); | float *workspace_temp = static_cast<float *>(workspace()); | ||||
| float *mat_workspace = workspace_temp + ws_size; | |||||
| for (int i = 0; i < batch; ++i) { | |||||
| for (int j = 0; j < groups; ++j) { | |||||
| for (int ci = 0; ci < m; ci += chunk) { | |||||
| int real_chunk = MSMIN(m - ci, chunk); | |||||
| float *mat_a = workspace_temp; | |||||
| const float *mat_b = w_addr + j * nweights / groups; | |||||
| float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch; | |||||
| float *im = x_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups); | |||||
| RollingIm2ColPackUnitFp32(im, conv_param_, mat_a, real_chunk, ci); | |||||
| GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a, k, mat_b, k, 0, mat_c, out_ch, mat_workspace); | |||||
| float *mat_workspace = workspace_temp + ws_size_; | |||||
| if (do_img2col_) { | |||||
| for (int i = 0; i < batch; ++i) { | |||||
| for (int j = 0; j < groups; ++j) { | |||||
| for (int ci = 0; ci < m; ci += chunk_) { | |||||
| int real_chunk = MSMIN(m - ci, chunk_); | |||||
| float *mat_a = workspace_temp; | |||||
| const float *mat_b = w_addr + j * nweights / groups; | |||||
| float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch; | |||||
| float *im = x_addr + i * in_ch * in_h * in_w + j * (in_ch / groups); | |||||
| RollingIm2ColPackUnitFp32(im, conv_param_, mat_a, real_chunk, ci); | |||||
| GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a, k, mat_b, k, 0, mat_c, out_ch, mat_workspace); | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| const float *mat_b = w_addr; | |||||
| const size_t in_plane_size = in_ch * in_h * in_w; | |||||
| for (int i = 0; i < batch; ++i) { | |||||
| float *im = x_addr + i * in_plane_size; | |||||
| for (int ci = 0; ci < m; ci += chunk_) { | |||||
| int real_chunk = MSMIN(m - ci, chunk_); | |||||
| float *mat_c = y_addr + i * n * m + ci * out_ch; | |||||
| int input_height = ci / out_w * conv_param_->stride_h_; | |||||
| int input_width = ci % out_w * conv_param_->stride_w_; | |||||
| int offset = (input_height * in_w + input_width) * in_ch; | |||||
| GemmMatmul(0, 1, real_chunk, n, k, 1, im + offset, k, mat_b, k, 0, mat_c, out_ch, mat_workspace); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -35,11 +35,12 @@ class ConvolutionTrainCPUKernel : public LiteKernel { | |||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| private: | private: | ||||
| int ws_size = 0; | |||||
| int ws_size_ = 0; | |||||
| bool do_img2col_ = true; | |||||
| #ifdef ENABLE_ARM32 | #ifdef ENABLE_ARM32 | ||||
| const int chunk = C4NUM; | |||||
| const int chunk_ = C4NUM * 2; | |||||
| #else | #else | ||||
| const int chunk = C12NUM; | |||||
| const int chunk_ = C12NUM * 2; | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| @@ -29,7 +29,7 @@ using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_Conv2DGradFilter; | using mindspore::schema::PrimitiveType_Conv2DGradFilter; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int ConvolutionGradFilterCPUKernel::Init() { | |||||
| int ConvolutionGradFilterCPUKernel::ReSize() { | |||||
| // dy is in input 0 | // dy is in input 0 | ||||
| // x is in input 1 | // x is in input 1 | ||||
| // dw is output 0 | // dw is output 0 | ||||
| @@ -51,16 +51,25 @@ int ConvolutionGradFilterCPUKernel::Init() { | |||||
| conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; | conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; | ||||
| conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; | conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; | ||||
| ws_size = chunk * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; | |||||
| ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; | |||||
| int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; | int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; | ||||
| int k = conv_param->output_channel_ / conv_param->group_; | int k = conv_param->output_channel_ / conv_param->group_; | ||||
| size_t mat_alloc = MatSizeTotal(k, n, chunk, n); | |||||
| set_workspace_size((ws_size + mat_alloc) * sizeof(float)); | |||||
| int thread_num = context_->thread_num_; | |||||
| mat_alloc_ = MatSizeTotal(k, n, chunk_, 0); | |||||
| set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float)); | |||||
| do_img2col_ = (conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) && | |||||
| (conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) && | |||||
| (conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) && | |||||
| (conv_param->stride_w_ == 1) && (conv_param->group_ == 1) | |||||
| ? false | |||||
| : true; | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ConvolutionGradFilterCPUKernel::ReSize() { return RET_OK; } | |||||
| int ConvolutionGradFilterCPUKernel::Init() { return ReSize(); } | |||||
| int ConvolutionGradFilterCPUKernel::Execute(int task_id) { | int ConvolutionGradFilterCPUKernel::Execute(int task_id) { | ||||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_); | auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_); | ||||
| @@ -72,7 +81,6 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) { | |||||
| auto dy_addr = reinterpret_cast<float *>(input_dy->MutableData()); | auto dy_addr = reinterpret_cast<float *>(input_dy->MutableData()); | ||||
| auto dw_addr = reinterpret_cast<float *>(out_dw->MutableData()); | auto dw_addr = reinterpret_cast<float *>(out_dw->MutableData()); | ||||
| int i, j; | |||||
| int nweights = out_dw->ElementsNum(); | int nweights = out_dw->ElementsNum(); | ||||
| int in_ch = conv_param->input_channel_; | int in_ch = conv_param->input_channel_; | ||||
| int in_h = conv_param->input_h_; | int in_h = conv_param->input_h_; | ||||
| @@ -88,22 +96,45 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) { | |||||
| int m = out_h * out_w; | int m = out_h * out_w; | ||||
| int n = k_h * k_w * in_ch / groups; | int n = k_h * k_w * in_ch / groups; | ||||
| int k = out_ch / groups; | int k = out_ch / groups; | ||||
| int thread_num = context_->thread_num_; | |||||
| float *workspace_temp = reinterpret_cast<float *>(workspace()); | float *workspace_temp = reinterpret_cast<float *>(workspace()); | ||||
| float *mat_workspace = workspace_temp + ws_size; | |||||
| // zero out pointer | |||||
| memset(dw_addr, 0, out_dw->Size()); | |||||
| for (i = 0; i < batch; ++i) { | |||||
| for (j = 0; j < groups; ++j) { | |||||
| for (int ci = 0; ci < m; ci += chunk) { | |||||
| int real_chunk = MSMIN(m - ci, chunk); | |||||
| float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; | |||||
| float *mat_b = workspace_temp; | |||||
| float *mat_c = dw_addr + j * nweights / groups; | |||||
| float *im = x_addr + (i * in_ch * in_h * in_w) + j * (in_ch / groups); | |||||
| memset(mat_b, 0, n * real_chunk * sizeof(float)); | |||||
| RollingIm2ColPackUnitFp32(im, conv_param, mat_b, real_chunk, ci); | |||||
| GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b, n, 1, mat_c, n, mat_workspace); | |||||
| float *mat_workspace = workspace_temp + ws_size_ * thread_num + task_id * (mat_alloc_ + k * n); | |||||
| float *mat_tmp = mat_workspace + mat_alloc_; | |||||
| int stride = UP_DIV(batch, thread_num); | |||||
| int count = MSMIN(stride, batch - stride * task_id); | |||||
| int start = stride * task_id; | |||||
| int end = start + count; | |||||
| if (do_img2col_) { | |||||
| for (int i = start; i < end; ++i) { | |||||
| for (int j = 0; j < groups; ++j) { | |||||
| for (int ci = 0; ci < m; ci += chunk_) { | |||||
| int real_chunk = MSMIN(m - ci, chunk_); | |||||
| float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; | |||||
| float *mat_b = workspace_temp + task_id * ws_size_; | |||||
| float *mat_c = dw_addr + j * nweights / groups; | |||||
| float *im = x_addr + (i * in_ch * in_h * in_w) + j * (in_ch / groups); | |||||
| RollingIm2ColPackUnitFp32(im, conv_param, mat_b, real_chunk, ci); | |||||
| GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b, n, 0, mat_tmp, n, mat_workspace); | |||||
| std::unique_lock<std::mutex> merge_lock(lock_); | |||||
| AddMatrix(mat_tmp, mat_c, 1, k, n, n); | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| float *mat_c = dw_addr; | |||||
| const size_t in_plane_size = in_ch * in_h * in_w; | |||||
| for (int i = start; i < end; ++i) { | |||||
| for (int ci = 0; ci < m; ci += chunk_) { | |||||
| int real_chunk = MSMIN(m - ci, chunk_); | |||||
| float *mat_a = dy_addr + i * m * k + ci * out_ch; | |||||
| float *im = x_addr + i * in_plane_size; | |||||
| int input_h = ci / out_w * conv_param->stride_h_; | |||||
| int input_w = ci % out_w * conv_param->stride_w_; | |||||
| int offset = (input_h * in_w + input_w) * in_ch; | |||||
| GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, im + offset, n, 0, mat_tmp, n, mat_workspace); | |||||
| std::unique_lock<std::mutex> merge_lock(lock_); | |||||
| AddMatrix(mat_tmp, mat_c, 1, k, n, n); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -122,7 +153,10 @@ int ConvolutionGradFilterRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int ConvolutionGradFilterCPUKernel::Run() { | int ConvolutionGradFilterCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradFilterRun, this, 1); | |||||
| auto *out_dw = out_tensors_.at(0); | |||||
| auto dw_addr = reinterpret_cast<float *>(out_dw->MutableData()); | |||||
| memset(dw_addr, 0, out_dw->Size()); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradFilterRun, this, context_->thread_num_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -36,11 +36,14 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel { | |||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| private: | private: | ||||
| size_t ws_size = 0; | |||||
| size_t ws_size_ = 0; | |||||
| bool do_img2col_ = true; | |||||
| std::mutex lock_; | |||||
| size_t mat_alloc_ = 0; | |||||
| #ifdef ENABLE_ARM32 | #ifdef ENABLE_ARM32 | ||||
| const int chunk = C4NUM; | |||||
| const int chunk_ = C4NUM * 2; | |||||
| #else | #else | ||||
| const int chunk = C12NUM; | |||||
| const int chunk_ = C12NUM * 2; | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -30,7 +30,7 @@ using mindspore::schema::PrimitiveType_Conv2DGradInput; | |||||
| using mindspore::schema::PrimitiveType_GroupConv2DGradInput; | using mindspore::schema::PrimitiveType_GroupConv2DGradInput; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int ConvolutionGradInputCPUKernel::Init() { | |||||
| int ConvolutionGradInputCPUKernel::ReSize() { | |||||
| auto *dy_tensor = in_tensors_.at(kInputIndex); | auto *dy_tensor = in_tensors_.at(kInputIndex); | ||||
| MS_ASSERT(dy_tensor != nullptr); | MS_ASSERT(dy_tensor != nullptr); | ||||
| auto *weight_tensor = in_tensors_.at(kWeightIndex); | auto *weight_tensor = in_tensors_.at(kWeightIndex); | ||||
| @@ -51,18 +51,17 @@ int ConvolutionGradInputCPUKernel::Init() { | |||||
| conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; | conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; | ||||
| conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; | conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; | ||||
| ws_size = chunk * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; | |||||
| ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; | |||||
| int n = conv_param->kernel_w_ * conv_param->kernel_h_ * conv_param->input_channel_ / conv_param->group_; | int n = conv_param->kernel_w_ * conv_param->kernel_h_ * conv_param->input_channel_ / conv_param->group_; | ||||
| int k = conv_param->output_channel_ / conv_param->group_; | int k = conv_param->output_channel_ / conv_param->group_; | ||||
| size_t mat_alloc = MatSizeTotal(chunk, n, k, 0); | |||||
| set_workspace_size((ws_size + mat_alloc) * sizeof(float)); | |||||
| int thread_num = context_->thread_num_; | |||||
| mat_alloc_ = MatSizeTotal(chunk_, n, k, 0); | |||||
| set_workspace_size((ws_size_ + mat_alloc_) * sizeof(float) * thread_num); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ConvolutionGradInputCPUKernel::ReSize() { return RET_OK; } | |||||
| int ConvolutionGradInputCPUKernel::Init() { return ReSize(); } | |||||
| int ConvolutionGradInputCPUKernel::Execute(int task_id) { | int ConvolutionGradInputCPUKernel::Execute(int task_id) { | ||||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_); | auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_); | ||||
| @@ -86,17 +85,21 @@ int ConvolutionGradInputCPUKernel::Execute(int task_id) { | |||||
| int groups = conv_param->group_; | int groups = conv_param->group_; | ||||
| int out_h = conv_param->output_h_; | int out_h = conv_param->output_h_; | ||||
| int out_w = conv_param->output_w_; | int out_w = conv_param->output_w_; | ||||
| int thread_num = context_->thread_num_; | |||||
| int m = out_h * out_w; | int m = out_h * out_w; | ||||
| int n = k_w * k_h * in_ch / groups; | int n = k_w * k_h * in_ch / groups; | ||||
| int k = out_ch / groups; | int k = out_ch / groups; | ||||
| float *workspace_temp = reinterpret_cast<float *>(workspace()); | |||||
| float *mat_workspace = workspace_temp + ws_size; | |||||
| memset(dx_addr, 0, sizeof(float) * batch * in_ch * in_h * in_w); | |||||
| for (i = 0; i < batch; ++i) { | |||||
| float *workspace_temp = reinterpret_cast<float *>(workspace()) + task_id * (mat_alloc_ + ws_size_); | |||||
| float *mat_workspace = workspace_temp + ws_size_; | |||||
| int stride = UP_DIV(batch, thread_num); | |||||
| int count = MSMIN(stride, batch - stride * task_id); | |||||
| int start = stride * task_id; | |||||
| int end = start + count; | |||||
| for (i = start; i < end; ++i) { | |||||
| for (j = 0; j < groups; ++j) { | for (j = 0; j < groups; ++j) { | ||||
| GemmCb gcb; | GemmCb gcb; | ||||
| for (int ci = 0; ci < m; ci += chunk) { | |||||
| for (int ci = 0; ci < m; ci += chunk_) { | |||||
| float *mat_b = nullptr; | float *mat_b = nullptr; | ||||
| if (ci == 0) { | if (ci == 0) { | ||||
| mat_b = w_addr + j * nweights / groups; | mat_b = w_addr + j * nweights / groups; | ||||
| @@ -108,7 +111,7 @@ int ConvolutionGradInputCPUKernel::Execute(int task_id) { | |||||
| mat_b = gcb.mat_b; | mat_b = gcb.mat_b; | ||||
| gcb.cb = 1; | gcb.cb = 1; | ||||
| } | } | ||||
| int real_chunk = MSMIN(m - ci, chunk); | |||||
| int real_chunk = MSMIN(m - ci, chunk_); | |||||
| float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; | float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; | ||||
| float *mat_c = workspace_temp; | float *mat_c = workspace_temp; | ||||
| GemmMatmulPlus(0, 0, real_chunk, n, k, 1, mat_a, out_ch, mat_b, n, 0, mat_c, n, mat_workspace, &gcb); | GemmMatmulPlus(0, 0, real_chunk, n, k, 1, mat_a, out_ch, mat_b, n, 0, mat_c, n, mat_workspace, &gcb); | ||||
| @@ -133,7 +136,15 @@ int ConvolutionGradInputRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int ConvolutionGradInputCPUKernel::Run() { | int ConvolutionGradInputCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradInputRun, this, 1); | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_); | |||||
| int batch = conv_param->output_batch_; | |||||
| int in_ch = conv_param->input_channel_; | |||||
| int in_h = conv_param->input_h_; | |||||
| int in_w = conv_param->input_w_; | |||||
| auto *out_dx = out_tensors_.at(0); | |||||
| auto dx_addr = reinterpret_cast<float *>(out_dx->MutableData()); | |||||
| memset(dx_addr, 0, sizeof(float) * batch * in_ch * in_h * in_w); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradInputRun, this, context_->thread_num_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -35,11 +35,12 @@ class ConvolutionGradInputCPUKernel : public LiteKernel { | |||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| private: | private: | ||||
| size_t ws_size = 0; | |||||
| size_t ws_size_ = 0; | |||||
| size_t mat_alloc_ = 0; | |||||
| #ifdef ENABLE_ARM32 | #ifdef ENABLE_ARM32 | ||||
| const int chunk = C4NUM; | |||||
| const int chunk_ = C4NUM; | |||||
| #else | #else | ||||
| const int chunk = C12NUM; | |||||
| const int chunk_ = C12NUM; | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -61,19 +61,26 @@ int DropoutCPUKernel::Execute(int task_id) { | |||||
| auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData()); | auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData()); | ||||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData()); | auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData()); | ||||
| auto mask = reinterpret_cast<float *>(out_tensors_.at(1)->MutableData()); | auto mask = reinterpret_cast<float *>(out_tensors_.at(1)->MutableData()); | ||||
| auto length = in_tensors_.at(kInputIndex)->ElementsNum(); | |||||
| auto param = reinterpret_cast<DropoutParameter *>(op_parameter_); | auto param = reinterpret_cast<DropoutParameter *>(op_parameter_); | ||||
| auto length = in_tensors_.at(kInputIndex)->ElementsNum(); | |||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| MS_LOG(ERROR) << "Dropout op_parameter_ nullptr"; | MS_LOG(ERROR) << "Dropout op_parameter_ nullptr"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (IsEval()) { | if (IsEval()) { | ||||
| std::copy(input_ptr, input_ptr + length, output_ptr); | |||||
| std::copy(&(input_ptr[start]), &(input_ptr[end]), &(output_ptr[start])); | |||||
| } else { | } else { | ||||
| std::default_random_engine generator; | std::default_random_engine generator; | ||||
| std::bernoulli_distribution distribution(param->ratio_); | std::bernoulli_distribution distribution(param->ratio_); | ||||
| for (int i = 0; i < length; i++) { | |||||
| for (size_t i = start; i < end; i++) { | |||||
| mask[i] = distribution(generator); | mask[i] = distribution(generator); | ||||
| output_ptr[i] = input_ptr[i] * mask[i] * scale_; | output_ptr[i] = input_ptr[i] * mask[i] * scale_; | ||||
| } | } | ||||
| @@ -92,7 +99,7 @@ int RunDropout(void *cdata, int task_id) { | |||||
| } | } | ||||
| int DropoutCPUKernel::Run() { | int DropoutCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, RunDropout, this, 1); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, RunDropout, this, thread_count_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "Dropout function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "Dropout function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -25,7 +25,7 @@ class DropoutCPUKernel : public LiteKernel { | |||||
| DropoutCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | DropoutCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} | |||||
| ~DropoutCPUKernel() override = default; | ~DropoutCPUKernel() override = default; | ||||
| @@ -35,7 +35,8 @@ class DropoutCPUKernel : public LiteKernel { | |||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| private: | private: | ||||
| float scale_; | |||||
| float scale_ = 1.0; | |||||
| int thread_count_ = 1; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -62,7 +62,13 @@ int DropoutGradCPUKernel::Execute(int task_id) { | |||||
| auto mask_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | auto mask_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | ||||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData()); | auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData()); | ||||
| auto length = in_tensors_.at(kInputIndex)->ElementsNum(); | auto length = in_tensors_.at(kInputIndex)->ElementsNum(); | ||||
| DropoutGrad(yt_ptr, mask_ptr, output_ptr, length, scale_); | |||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| DropoutGrad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), count, scale_); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -78,7 +84,7 @@ int RunDropoutGrad(void *cdata, int task_id) { | |||||
| } | } | ||||
| int DropoutGradCPUKernel::Run() { | int DropoutGradCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, RunDropoutGrad, this, 1); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, RunDropoutGrad, this, thread_count_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "Dropout Grad function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "Dropout Grad function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -25,7 +25,7 @@ class DropoutGradCPUKernel : public LiteKernel { | |||||
| DropoutGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | DropoutGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} | |||||
| ~DropoutGradCPUKernel() override = default; | ~DropoutGradCPUKernel() override = default; | ||||
| @@ -36,6 +36,7 @@ class DropoutGradCPUKernel : public LiteKernel { | |||||
| private: | private: | ||||
| float scale_; | float scale_; | ||||
| int thread_count_ = 1; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -29,36 +29,34 @@ using mindspore::schema::PrimitiveType_NegGrad; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| namespace { | namespace { | ||||
| int NegGradRun(void *cdata, int thread_id) { | |||||
| int NegGradRun(void *cdata, int task_id) { | |||||
| MS_ASSERT(cdata != nullptr); | MS_ASSERT(cdata != nullptr); | ||||
| auto kernel = reinterpret_cast<NegGradCPUKernel *>(cdata); | auto kernel = reinterpret_cast<NegGradCPUKernel *>(cdata); | ||||
| MS_ASSERT(kernel != nullptr); | MS_ASSERT(kernel != nullptr); | ||||
| return kernel->DoNegGrad(thread_id); | |||||
| return kernel->DoNegGrad(task_id); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| int NegGradCPUKernel::Init() { return RET_OK; } | int NegGradCPUKernel::Init() { return RET_OK; } | ||||
| int NegGradCPUKernel::DoNegGrad(int thread_id) { | |||||
| auto dy = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||||
| auto dx = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | |||||
| int dy_size = in_tensors_.at(0)->ElementsNum(); | |||||
| int size = MSMIN(thread_stride_, static_cast<int>(dy_size - thread_id * thread_stride_)); | |||||
| if (size <= 0) { | |||||
| return RET_OK; | |||||
| } | |||||
| int offset = thread_id * thread_stride_; | |||||
| ElementNegative(dy + offset, dx + offset, size); | |||||
| int NegGradCPUKernel::DoNegGrad(int task_id) { | |||||
| auto dy = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||||
| auto dx = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| ElementNegative(dy + start, dx + start, count); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int NegGradCPUKernel::ReSize() { return RET_OK; } | int NegGradCPUKernel::ReSize() { return RET_OK; } | ||||
| int NegGradCPUKernel::Run() { | int NegGradCPUKernel::Run() { | ||||
| int dy_size = in_tensors_.at(0)->ElementsNum(); | |||||
| op_parameter_->thread_num_ = MSMIN(op_parameter_->thread_num_, static_cast<int>(dy_size)); | |||||
| thread_stride_ = UP_DIV(dy_size, op_parameter_->thread_num_); | |||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, NegGradRun, this, op_parameter_->thread_num_); | |||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, NegGradRun, this, thread_count_); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "parallel launch fail!ret: " << ret; | MS_LOG(ERROR) << "parallel launch fail!ret: " << ret; | ||||
| return ret; | return ret; | ||||
| @@ -28,7 +28,7 @@ class NegGradCPUKernel : public LiteKernel { | |||||
| explicit NegGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | explicit NegGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} | |||||
| ~NegGradCPUKernel() override {} | ~NegGradCPUKernel() override {} | ||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override; | int ReSize() override; | ||||
| @@ -36,7 +36,7 @@ class NegGradCPUKernel : public LiteKernel { | |||||
| int DoNegGrad(int thread_id); | int DoNegGrad(int thread_id); | ||||
| private: | private: | ||||
| int thread_stride_; | |||||
| int thread_count_; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -29,7 +29,7 @@ using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_PoolingGrad; | using mindspore::schema::PrimitiveType_PoolingGrad; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int PoolingGradCPUKernel::Init() { | |||||
| int PoolingGradCPUKernel::ReSize() { | |||||
| PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(op_parameter_); | PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(op_parameter_); | ||||
| auto in_shape = in_tensors_.at(0)->shape(); | auto in_shape = in_tensors_.at(0)->shape(); | ||||
| @@ -59,7 +59,7 @@ int PoolingGradCPUKernel::Init() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int PoolingGradCPUKernel::ReSize() { return RET_OK; } | |||||
| int PoolingGradCPUKernel::Init() { return ReSize(); } | |||||
| int PoolingGradCPUKernel::Execute(int task_id) { | int PoolingGradCPUKernel::Execute(int task_id) { | ||||
| PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(op_parameter_); | PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(op_parameter_); | ||||
| @@ -45,13 +45,20 @@ int PowerGradCPUKernel::Execute(int task_id) { | |||||
| auto dy_addr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | auto dy_addr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | ||||
| auto x_addr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | auto x_addr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | ||||
| auto dx_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto dx_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| auto size = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| float exp = power_ - 1; | float exp = power_ - 1; | ||||
| Power(x_addr, &exp, dx_addr, size, scale_, shift_, true); | |||||
| ElementMul(dx_addr, dy_addr, dx_addr, size); | |||||
| Power(&(x_addr[start]), &exp, &(dx_addr[start]), count, scale_, shift_, true); | |||||
| ElementMul(&(dx_addr[start]), &(dy_addr[start]), &(dx_addr[start]), count); | |||||
| float scale = scale_ * power_; | float scale = scale_ * power_; | ||||
| for (int i = 0; i < size; i++) { | |||||
| for (size_t i = start; i < end; i++) { | |||||
| dx_addr[i] *= scale; | dx_addr[i] *= scale; | ||||
| } | } | ||||
| @@ -69,7 +76,7 @@ int PowerGradRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int PowerGradCPUKernel::Run() { | int PowerGradCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, PowerGradRun, this, 1); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, PowerGradRun, this, thread_count_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "power grad function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "power grad function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -27,7 +27,7 @@ class PowerGradCPUKernel : public LiteKernel { | |||||
| PowerGradCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs, | PowerGradCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(param, inputs, outputs, ctx, primitive) { | |||||
| : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { | |||||
| PowerParameter *power_param = reinterpret_cast<PowerParameter *>(param); | PowerParameter *power_param = reinterpret_cast<PowerParameter *>(param); | ||||
| power_ = power_param->power_; | power_ = power_param->power_; | ||||
| scale_ = power_param->scale_; | scale_ = power_param->scale_; | ||||
| @@ -41,6 +41,7 @@ class PowerGradCPUKernel : public LiteKernel { | |||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| private: | private: | ||||
| int thread_count_; | |||||
| float power_; | float power_; | ||||
| float scale_; | float scale_; | ||||
| float shift_; | float shift_; | ||||
| @@ -16,6 +16,7 @@ | |||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/fp32_grad/sgd.h" | #include "src/runtime/kernel/arm/fp32_grad/sgd.h" | ||||
| #include <algorithm> | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| @@ -37,36 +38,42 @@ int SgdCPUKernel::Execute(int task_id) { | |||||
| float learning_rate = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData())[0]; | float learning_rate = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData())[0]; | ||||
| auto gradient = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | auto gradient = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | ||||
| float moment = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0]; | float moment = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0]; | ||||
| size_t elem_num = in_tensors_.at(0)->ElementsNum(); | |||||
| auto stat = reinterpret_cast<float *>(in_tensors_.at(5)->MutableData()); | auto stat = reinterpret_cast<float *>(in_tensors_.at(5)->MutableData()); | ||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| if (stat[0] > 0) { | |||||
| stat[0] = 0; | |||||
| memcpy(accumulate, gradient, elem_num * sizeof(float)); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| if (stat[task_id] > 0) { | |||||
| stat[task_id] = 0; // Haim Please approve this | |||||
| std::copy(&(gradient[start]), &(gradient[end]), &(accumulate[start])); | |||||
| if (sgd_param_->use_nesterov_) { | if (sgd_param_->use_nesterov_) { | ||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| for (size_t i = start; i < end; ++i) { | |||||
| weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; | weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; | ||||
| } | } | ||||
| } else { | } else { | ||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| for (size_t i = start; i < end; ++i) { | |||||
| weight[i] -= accumulate[i] * learning_rate; | weight[i] -= accumulate[i] * learning_rate; | ||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| if (moment > 0.f) { | if (moment > 0.f) { | ||||
| if (sgd_param_->use_nesterov_) { | if (sgd_param_->use_nesterov_) { | ||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| for (size_t i = start; i < end; ++i) { | |||||
| accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - sgd_param_->dampening_); | accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - sgd_param_->dampening_); | ||||
| weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; | weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; | ||||
| } | } | ||||
| } else { | } else { | ||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| for (size_t i = start; i < end; ++i) { | |||||
| accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - sgd_param_->dampening_); | accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - sgd_param_->dampening_); | ||||
| weight[i] -= accumulate[i] * learning_rate; | weight[i] -= accumulate[i] * learning_rate; | ||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| for (size_t i = start; i < end; ++i) { | |||||
| weight[i] -= gradient[i] * learning_rate; | weight[i] -= gradient[i] * learning_rate; | ||||
| } | } | ||||
| } | } | ||||
| @@ -85,7 +92,7 @@ int SgdRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int SgdCPUKernel::Run() { | int SgdCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, SgdRun, this, 1); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, SgdRun, this, thread_count_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "SGD function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "SGD function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -114,6 +121,17 @@ int SgdCPUKernel::Init() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int SgdCPUKernel::SetLearningRate(float lr) { | |||||
| auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||||
| learning_rate_tensor[0] = lr; | |||||
| return RET_OK; | |||||
| } | |||||
| float SgdCPUKernel::GetLearningRate() { | |||||
| auto learning_rate_tensor = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||||
| return learning_rate_tensor[0]; | |||||
| } | |||||
| kernel::LiteKernel *CpuSgdFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | kernel::LiteKernel *CpuSgdFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | ||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | const lite::InnerContext *ctx, const kernel::KernelKey &desc, | ||||
| @@ -18,16 +18,18 @@ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SGD_H_ | #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SGD_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/lite_kernel.h" | |||||
| #include "src/train/optimizer_kernel.h" | |||||
| #include "nnacl/fp32_grad/optimizer.h" | #include "nnacl/fp32_grad/optimizer.h" | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| class SgdCPUKernel : public LiteKernel { | |||||
| class SgdCPUKernel : public OptimizerKernel { | |||||
| public: | public: | ||||
| explicit SgdCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | explicit SgdCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), sgd_param_(nullptr) { | |||||
| : OptimizerKernel(parameter, inputs, outputs, ctx, primitive), | |||||
| thread_count_(ctx->thread_num_), | |||||
| sgd_param_(nullptr) { | |||||
| sgd_param_ = reinterpret_cast<SgdParameter *>(parameter); | sgd_param_ = reinterpret_cast<SgdParameter *>(parameter); | ||||
| } | } | ||||
| ~SgdCPUKernel() override {} | ~SgdCPUKernel() override {} | ||||
| @@ -35,8 +37,11 @@ class SgdCPUKernel : public LiteKernel { | |||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| int SetLearningRate(float lr) override; | |||||
| float GetLearningRate() override; | |||||
| private: | private: | ||||
| int thread_count_; | |||||
| SgdParameter *sgd_param_; | SgdParameter *sgd_param_; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -35,12 +35,19 @@ int SmoothL1LossCPUKernel::Execute(int task_id) { | |||||
| auto target = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | auto target = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | ||||
| auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| const size_t tensor_len = in_tensors_.at(0)->ElementsNum(); | |||||
| const size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| int count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| const float zero = 0.0f; | const float zero = 0.0f; | ||||
| const float half = 0.5f; | const float half = 0.5f; | ||||
| const float beta = smooth_l1_loss_param->beta_; | const float beta = smooth_l1_loss_param->beta_; | ||||
| for (uint64_t i = 0; i < tensor_len; ++i) { | |||||
| for (uint64_t i = start; i < end; ++i) { | |||||
| float diff = predict[i] - target[i]; | float diff = predict[i] - target[i]; | ||||
| if (diff < zero) { | if (diff < zero) { | ||||
| diff = -diff; | diff = -diff; | ||||
| @@ -66,7 +73,7 @@ int SmoothL1LossRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int SmoothL1LossCPUKernel::Run() { | int SmoothL1LossCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossRun, this, 1); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossRun, this, thread_count_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "SmoothL1Loss function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "SmoothL1Loss function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -27,7 +27,9 @@ class SmoothL1LossCPUKernel : public LiteKernel { | |||||
| explicit SmoothL1LossCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | explicit SmoothL1LossCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), smooth_l1_param_(nullptr) { | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), | |||||
| smooth_l1_param_(nullptr), | |||||
| thread_count_(ctx->thread_num_) { | |||||
| smooth_l1_param_ = reinterpret_cast<SmoothL1LossParameter *>(parameter); | smooth_l1_param_ = reinterpret_cast<SmoothL1LossParameter *>(parameter); | ||||
| } | } | ||||
| ~SmoothL1LossCPUKernel() override {} | ~SmoothL1LossCPUKernel() override {} | ||||
| @@ -38,6 +40,7 @@ class SmoothL1LossCPUKernel : public LiteKernel { | |||||
| private: | private: | ||||
| SmoothL1LossParameter *smooth_l1_param_; | SmoothL1LossParameter *smooth_l1_param_; | ||||
| int thread_count_ = 1; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -36,10 +36,17 @@ int SmoothL1LossGradCPUKernel::Execute(int task_id) { | |||||
| auto d_loss = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | auto d_loss = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | ||||
| auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| const size_t tensor_len = in_tensors_.at(0)->ElementsNum(); | |||||
| const size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| const float beta = smooth_l1_loss_param->beta_; | const float beta = smooth_l1_loss_param->beta_; | ||||
| for (uint64_t i = 0; i < tensor_len; ++i) { | |||||
| for (uint64_t i = start; i < end; ++i) { | |||||
| float diff = predict[i] - target[i]; | float diff = predict[i] - target[i]; | ||||
| if (diff > beta) { | if (diff > beta) { | ||||
| out[i] = d_loss[i]; | out[i] = d_loss[i]; | ||||
| @@ -63,7 +70,7 @@ int SmoothL1LossGradRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int SmoothL1LossGradCPUKernel::Run() { | int SmoothL1LossGradCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossGradRun, this, 1); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossGradRun, this, thread_count_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "SmoothL1LossGrad function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "SmoothL1LossGrad function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -27,7 +27,9 @@ class SmoothL1LossGradCPUKernel : public LiteKernel { | |||||
| explicit SmoothL1LossGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | explicit SmoothL1LossGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), smooth_l1_param_(nullptr) { | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), | |||||
| smooth_l1_param_(nullptr), | |||||
| thread_count_(ctx->thread_num_) { | |||||
| smooth_l1_param_ = reinterpret_cast<SmoothL1LossParameter *>(parameter); | smooth_l1_param_ = reinterpret_cast<SmoothL1LossParameter *>(parameter); | ||||
| } | } | ||||
| ~SmoothL1LossGradCPUKernel() override {} | ~SmoothL1LossGradCPUKernel() override {} | ||||
| @@ -38,6 +40,7 @@ class SmoothL1LossGradCPUKernel : public LiteKernel { | |||||
| private: | private: | ||||
| SmoothL1LossParameter *smooth_l1_param_; | SmoothL1LossParameter *smooth_l1_param_; | ||||
| int thread_count_; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -29,7 +29,7 @@ using mindspore::schema::PrimitiveType_SoftmaxCrossEntropy; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int SoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return RET_OK; } | |||||
| int SoftmaxCrossEntropyWithLogitsCPUKernel::Init() { return ReSize(); } | |||||
| void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *labels, const float *logits, float *grads, | void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *labels, const float *logits, float *grads, | ||||
| float *output2) const { | float *output2) const { | ||||
| @@ -100,7 +100,7 @@ int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int SoftmaxCrossEntropyWithLogitsCPUKernel::Init() { | |||||
| int SoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { | |||||
| auto dims = in_tensors_.at(0)->shape(); | auto dims = in_tensors_.at(0)->shape(); | ||||
| param_->n_dim_ = 2; | param_->n_dim_ = 2; | ||||
| param_->number_of_classes_ = dims.at(1); | param_->number_of_classes_ = dims.at(1); | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include <vector> | #include <vector> | ||||
| #include <algorithm> | |||||
| #include "src/runtime/kernel/arm/fp32_grad/tuple_getitem.h" | #include "src/runtime/kernel/arm/fp32_grad/tuple_getitem.h" | ||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| @@ -47,7 +48,15 @@ int TupleGetItemCPUKernel::Execute(int task_id) { | |||||
| auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | ||||
| auto out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| memcpy(out, in, in_tensors_.at(0)->Size()); | |||||
| size_t length = in_tensors_.at(0)->ElementsNum(); | |||||
| size_t stride = UP_DIV(length, thread_count_); | |||||
| size_t count = MSMIN(stride, length - stride * task_id); | |||||
| size_t start = stride * task_id; | |||||
| size_t end = start + count; | |||||
| std::copy(&(in[start]), &(in[end]), &(out[start])); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -62,7 +71,7 @@ int TupleRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int TupleGetItemCPUKernel::Run() { | int TupleGetItemCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, TupleRun, this, 1); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, TupleRun, this, thread_count_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "tuple function error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "tuple function error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -27,7 +27,7 @@ class TupleGetItemCPUKernel : public LiteKernel { | |||||
| explicit TupleGetItemCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | explicit TupleGetItemCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const lite::PrimitiveC *primitive) | const lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { | |||||
| param = parameter; | param = parameter; | ||||
| } | } | ||||
| ~TupleGetItemCPUKernel() override = default; | ~TupleGetItemCPUKernel() override = default; | ||||
| @@ -38,6 +38,7 @@ class TupleGetItemCPUKernel : public LiteKernel { | |||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| private: | private: | ||||
| int thread_count_ = 1; | |||||
| OpParameter *param; | OpParameter *param; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -0,0 +1,98 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "include/train/classification_train_accuracy_monitor.h" | |||||
| #include <sys/stat.h> | |||||
| #include <algorithm> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <iostream> | |||||
| #include <fstream> | |||||
| #include <memory> | |||||
| #include "include/errorcode.h" | |||||
| #include "include/train_session.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/tensor.h" | |||||
| #include "src/train/loss_kernel.h" | |||||
| #include "src/train/optimizer_kernel.h" | |||||
| #include "src/sub_graph_kernel.h" | |||||
| #include "src/train/train_populate_parameter.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| #include "src/executor.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/convolution.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| void ClassificationTrainAccuracyMonitor::Begin(const session::TrainLoopCallBackData &cb_data) { | |||||
| if (cb_data.epoch_ == 0) accuracies_.clear(); | |||||
| } | |||||
| void ClassificationTrainAccuracyMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) { | |||||
| if (accuracies_.size() != cb_data.epoch_) { | |||||
| MS_LOG(WARNING) << "Accuracies array does not match epoch number"; | |||||
| } else { | |||||
| accuracies_.push_back(std::make_pair(cb_data.epoch_, 0.0)); | |||||
| } | |||||
| } | |||||
| int ClassificationTrainAccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) { | |||||
| if (cb_data.step_ > 0) accuracies_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_); | |||||
| if ((cb_data.epoch_ + 1) % print_every_n_ == 0) { | |||||
| std::cout << cb_data.epoch_ + 1 << ":\tTraining Accuracy is " << accuracies_.at(cb_data.epoch_).second << std::endl; | |||||
| } | |||||
| return mindspore::session::RET_CONTINUE; | |||||
| } | |||||
| void ClassificationTrainAccuracyMonitor::StepEnd(const session::TrainLoopCallBackData &cb_data) { | |||||
| auto inputs = cb_data.session_->GetInputs(); | |||||
| auto outputs = cb_data.session_->GetPredictions(); | |||||
| auto labels = reinterpret_cast<float *>(inputs.at(1)->MutableData()); | |||||
| for (auto it = outputs.begin(); it != outputs.end(); ++it) { | |||||
| if (it->second->ElementsNum() == inputs.at(1)->ElementsNum()) { | |||||
| int batch_size = inputs.at(1)->shape().at(0); | |||||
| int num_of_classes = inputs.at(1)->shape().at(1); | |||||
| auto predictions = reinterpret_cast<float *>(it->second->MutableData()); | |||||
| float accuracy = 0.0; | |||||
| for (int b = 0; b < batch_size; b++) { | |||||
| int label = 0; | |||||
| int max_idx = 0; | |||||
| float max_label_score = labels[num_of_classes * b]; | |||||
| float max_score = predictions[num_of_classes * b]; | |||||
| for (int c = 1; c < num_of_classes; c++) { | |||||
| if (predictions[num_of_classes * b + c] > max_score) { | |||||
| max_score = predictions[num_of_classes * b + c]; | |||||
| max_idx = c; | |||||
| } | |||||
| if (labels[num_of_classes * b + c] > max_label_score) { | |||||
| max_label_score = labels[num_of_classes * b + c]; | |||||
| label = c; | |||||
| } | |||||
| } | |||||
| if (label == max_idx) accuracy += 1.0; | |||||
| } | |||||
| accuracy /= static_cast<float>(batch_size); | |||||
| accuracies_.at(cb_data.epoch_).second = accuracy; | |||||
| return; | |||||
| } | |||||
| } | |||||
| MS_LOG(WARNING) << "Model does not have a loss output tensor of size 1"; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "include/train/loss_monitor.h" | |||||
| #include <sys/stat.h> | |||||
| #include <algorithm> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <iostream> | |||||
| #include <fstream> | |||||
| #include <memory> | |||||
| #include "include/errorcode.h" | |||||
| #include "include/train_session.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/tensor.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| void LossMonitor::Begin(const session::TrainLoopCallBackData &cb_data) { | |||||
| if (cb_data.epoch_ == 0) losses_.clear(); | |||||
| } | |||||
| void LossMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) { | |||||
| if (losses_.size() != cb_data.epoch_) { | |||||
| MS_LOG(WARNING) << "losses array does not match epoch number"; | |||||
| } else { | |||||
| losses_.push_back(std::make_pair(cb_data.epoch_, 0.0)); | |||||
| } | |||||
| } | |||||
| int LossMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) { | |||||
| if (cb_data.step_ > 0) losses_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_); | |||||
| if ((cb_data.epoch_ + 1) % print_every_n_ == 0) { | |||||
| std::cout << cb_data.epoch_ + 1 << ":\tLoss is " << losses_.at(cb_data.epoch_).second << std::endl; | |||||
| } | |||||
| return mindspore::session::RET_CONTINUE; | |||||
| } | |||||
| void LossMonitor::StepEnd(const session::TrainLoopCallBackData &cb_data) { | |||||
| auto outputs = cb_data.session_->GetOutputs(); | |||||
| for (auto it = outputs.begin(); it != outputs.end(); ++it) { | |||||
| if (it->second->ElementsNum() == 1) { | |||||
| auto loss = reinterpret_cast<float *>(it->second->MutableData()); | |||||
| losses_.at(cb_data.epoch_).second += loss[0]; | |||||
| return; | |||||
| } | |||||
| } | |||||
| MS_LOG(WARNING) << "Model does not have a loss output tensor of size 1"; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,75 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "include/train/lr_scheduler.h" | |||||
| #include <sys/stat.h> | |||||
| #include <algorithm> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <iostream> | |||||
| #include <fstream> | |||||
| #include <memory> | |||||
| #include "include/errorcode.h" | |||||
| #include "include/train_session.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/tensor.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| int MultiplicativeLRLambda(float *lr, int epoch, void *lr_cb_data) { | |||||
| if ((lr == nullptr) || (lr_cb_data == nullptr)) { | |||||
| MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda"; | |||||
| return DONT_UPDATE_LR; | |||||
| } | |||||
| float mult = *(static_cast<float *>(lr_cb_data)); | |||||
| *lr = *lr * mult; | |||||
| return UPDATE_LR; | |||||
| } | |||||
| int StepLRLambda(float *lr, int epoch, void *lr_cb_data) { | |||||
| if ((lr == nullptr) || (lr_cb_data == nullptr)) { | |||||
| MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda"; | |||||
| return DONT_UPDATE_LR; | |||||
| } | |||||
| struct StepLRLambda *step_lr_data = (static_cast<struct StepLRLambda *>(lr_cb_data)); | |||||
| if (((epoch + 1) % step_lr_data->step_size) == 0) { | |||||
| *lr = *lr * step_lr_data->gamma; | |||||
| return UPDATE_LR; | |||||
| } | |||||
| return DONT_UPDATE_LR; | |||||
| } | |||||
| LRScheduler::LRScheduler(LR_Lambda lambda_func, void *lr_cb_data, int step) | |||||
| : lambda_func_(lambda_func), lr_data_(lr_cb_data), step_(step) {} | |||||
| int LRScheduler::EpochEnd(const session::TrainLoopCallBackData &cb_data) { | |||||
| if (((cb_data.epoch_ + 1) % step_) == 0) { | |||||
| float lr = cb_data.session_->GetLearningRate(); | |||||
| int update = lambda_func_(&lr, cb_data.epoch_, lr_data_); | |||||
| if (update == UPDATE_LR) { | |||||
| int ret = cb_data.session_->SetLearningRate(lr); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Error setting Leraning rate in train session"; | |||||
| return mindspore::session::RET_EXIT; | |||||
| } | |||||
| } | |||||
| } | |||||
| return mindspore::session::RET_CONTINUE; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,35 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_ | |||||
| #define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| namespace mindspore::kernel { | |||||
| class OptimizerKernel : public LiteKernel { | |||||
| public: | |||||
| OptimizerKernel() = default; | |||||
| OptimizerKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| ~OptimizerKernel() = default; | |||||
| virtual int SetLearningRate(float lr) = 0; | |||||
| virtual float GetLearningRate() = 0; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_ | |||||
| @@ -0,0 +1,99 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/train/train_loop.h" | |||||
| #include <sys/stat.h> | |||||
| #include <algorithm> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <iostream> | |||||
| #include <fstream> | |||||
| #include <memory> | |||||
| #include "include/errorcode.h" | |||||
| #include "include/train_session.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/tensor.h" | |||||
| #include "src/train/loss_kernel.h" | |||||
| #include "src/train/optimizer_kernel.h" | |||||
| #include "src/sub_graph_kernel.h" | |||||
| #include "src/train/train_populate_parameter.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| #include "src/executor.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/kernel/arm/fp32_grad/convolution.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| using session::RET_CONTINUE; | |||||
| using session::RET_EXIT; | |||||
| using session::RET_STOP_TRAINING; | |||||
| TrainLoop::~TrainLoop() { | |||||
| if (train_session_ != nullptr) delete train_session_; | |||||
| } | |||||
| int TrainLoop::Train(int epochs, std::vector<session::TrainLoopCallBack *> cbs) { | |||||
| train_session_->Train(); | |||||
| session::TrainLoopCallBackData cb_data(true, epoch_, train_session_, this); | |||||
| for (auto cb : cbs) cb->Begin(cb_data); | |||||
| int steps_in_epoch = 1; // should be data_size/batch_size | |||||
| for (int i = 0; i < epochs; i++) { | |||||
| cb_data.epoch_ = epoch_++; | |||||
| for (auto cb : cbs) cb->EpochBegin(cb_data); | |||||
| for (int s = 0; s < steps_in_epoch; s++) { | |||||
| cb_data.step_ = s; | |||||
| for (auto cb : cbs) cb->StepBegin(cb_data); | |||||
| train_session_->RunGraph(before_cb_, after_cb_); | |||||
| for (auto cb : cbs) cb->StepEnd(cb_data); | |||||
| } | |||||
| int break_loop = false; | |||||
| for (auto cb : cbs) { | |||||
| int ret = cb->EpochEnd(cb_data); | |||||
| if (ret != RET_CONTINUE) { | |||||
| if (ret == RET_EXIT) { | |||||
| MS_LOG(ERROR) << "Error in TrainLoop callback"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (ret == RET_STOP_TRAINING) { | |||||
| break_loop = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (break_loop) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| for (auto cb : cbs) cb->End(cb_data); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| session::TrainLoop *session::TrainLoop::CreateTrainLoop(const std::string &model_filename, lite::Context *context, | |||||
| int batch_size) { | |||||
| auto train_session = session::TrainSession::CreateSession(model_filename, context); | |||||
| auto loop = new (std::nothrow) lite::TrainLoop(train_session); | |||||
| return loop; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_ | |||||
| #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <tuple> | |||||
| #include <unordered_map> | |||||
| #include "src/ops/primitive_c.h" | |||||
| #include "include/train/train_loop.h" | |||||
| #include "include/train_session.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TrainLoop : virtual public session::TrainLoop { | |||||
| public: | |||||
| explicit TrainLoop(session::TrainSession *session) : train_session_(session) {} | |||||
| session::TrainSession *train_session() override { return train_session_; } | |||||
| int Reset() override { | |||||
| epoch_ = 0; | |||||
| return RET_OK; | |||||
| } | |||||
| virtual ~TrainLoop(); | |||||
| int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) override { | |||||
| before_cb_ = before; | |||||
| after_cb_ = after; | |||||
| return RET_OK; | |||||
| } | |||||
| int Train(int epochs, std::vector<session::TrainLoopCallBack *> cbs) override; | |||||
| protected: | |||||
| session::TrainSession *train_session_ = nullptr; | |||||
| unsigned int epoch_ = 0; | |||||
| KernelCallBack before_cb_ = nullptr; | |||||
| KernelCallBack after_cb_ = nullptr; | |||||
| int batch_size; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_ | |||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "src/train/train_populate_parameter.h" | #include "src/train/train_populate_parameter.h" | ||||
| #include <algorithm> | |||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "src/ops/pooling_grad.h" | #include "src/ops/pooling_grad.h" | ||||
| #include "nnacl/pooling_parameter.h" | #include "nnacl/pooling_parameter.h" | ||||
| @@ -517,12 +518,15 @@ OpParameter *PopulateArithmeticGradParameter(const mindspore::lite::PrimitiveC * | |||||
| arithmetic_param->broadcasting_ = ((lite::ArithmeticGrad *)primitive)->Broadcasting(); | arithmetic_param->broadcasting_ = ((lite::ArithmeticGrad *)primitive)->Broadcasting(); | ||||
| arithmetic_param->ndim_ = ((lite::ArithmeticGrad *)primitive)->NDims(); | arithmetic_param->ndim_ = ((lite::ArithmeticGrad *)primitive)->NDims(); | ||||
| auto tmp_shape = ((lite::ArithmeticGrad *)primitive)->x1Shape(); | |||||
| memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| tmp_shape = ((lite::ArithmeticGrad *)primitive)->x2Shape(); | |||||
| memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| tmp_shape = ((lite::ArithmeticGrad *)primitive)->dyShape(); | |||||
| memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); | |||||
| auto shape = ((lite::ArithmeticGrad *)primitive)->x1Shape(); | |||||
| auto source = static_cast<int *>(shape.data()); | |||||
| std::copy(source, source + shape.size(), arithmetic_param->in_shape0_); | |||||
| shape = ((lite::ArithmeticGrad *)primitive)->x2Shape(); | |||||
| source = static_cast<int *>(shape.data()); | |||||
| std::copy(source, source + shape.size(), arithmetic_param->in_shape1_); | |||||
| shape = ((lite::ArithmeticGrad *)primitive)->dyShape(); | |||||
| source = static_cast<int *>(shape.data()); | |||||
| std::copy(source, source + shape.size(), arithmetic_param->out_shape_); | |||||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | return reinterpret_cast<OpParameter *>(arithmetic_param); | ||||
| } | } | ||||
| @@ -26,6 +26,7 @@ | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/tensor.h" | #include "src/tensor.h" | ||||
| #include "src/train/loss_kernel.h" | #include "src/train/loss_kernel.h" | ||||
| #include "src/train/optimizer_kernel.h" | |||||
| #include "src/sub_graph_kernel.h" | #include "src/sub_graph_kernel.h" | ||||
| #include "src/train/train_populate_parameter.h" | #include "src/train/train_populate_parameter.h" | ||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| @@ -49,10 +50,8 @@ TrainSession::TrainSession() { kernel::PopulateTrainParameters(); } | |||||
| std::vector<CreatorOp> TrainSession::ReplaceOps() { | std::vector<CreatorOp> TrainSession::ReplaceOps() { | ||||
| const std::vector<CreatorOp> replace = { | const std::vector<CreatorOp> replace = { | ||||
| {{mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Conv2D}, | |||||
| mindspore::kernel::CpuConvTrainFp32KernelCreator}, | |||||
| {{mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_DepthwiseConv2D}, | |||||
| mindspore::kernel::CpuConvTrainFp32KernelCreator}}; | |||||
| // currently no ops are Hijacked by TrainSession | |||||
| }; | |||||
| mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance(); | mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance(); | ||||
| std::vector<CreatorOp> results; | std::vector<CreatorOp> results; | ||||
| for (auto v : replace) { | for (auto v : replace) { | ||||
| @@ -98,7 +97,7 @@ int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) { | |||||
| RestoreOps(restore); | RestoreOps(restore); | ||||
| CompileTrainKernels(); // Prepare a list of train kernels | CompileTrainKernels(); // Prepare a list of train kernels | ||||
| CompileInferenceKernels(); // Prepare a list of eval kernels | CompileInferenceKernels(); // Prepare a list of eval kernels | ||||
| CompileOptimizedKernels(); // Prepare a list of kenels which are optimized (weight update step) | |||||
| CompileOptimizedKernels(); // Prepare a list of kernels which are optimized (weight update step) | |||||
| CompileTrainOutputs(); // prepare outputs in train mode | CompileTrainOutputs(); // prepare outputs in train mode | ||||
| CompileEvalOutputs(); // prepare outputs in eval mode | CompileEvalOutputs(); // prepare outputs in eval mode | ||||
| AllocWorkSpace(); | AllocWorkSpace(); | ||||
| @@ -302,6 +301,30 @@ void TrainSession::CompileOptimizedKernels() { | |||||
| } | } | ||||
| } | } | ||||
| int TrainSession::SetLearningRate(float learning_rate) { | |||||
| for (auto kernel : this->train_kernels_) { | |||||
| if (IsOptimizer(kernel)) { | |||||
| auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel); | |||||
| auto ret = optimizer->SetLearningRate(learning_rate); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << kernel->name() << " failed to set learning rate"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| float TrainSession::GetLearningRate() { | |||||
| for (auto kernel : this->train_kernels_) { | |||||
| if (IsOptimizer(kernel)) { | |||||
| auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel); | |||||
| return optimizer->GetLearningRate(); | |||||
| } | |||||
| } | |||||
| return 0.0; | |||||
| } | |||||
| bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const { | bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const { | ||||
| return (kernel->Type() == schema::PrimitiveType_SoftmaxCrossEntropy || | return (kernel->Type() == schema::PrimitiveType_SoftmaxCrossEntropy || | ||||
| kernel->Type() == schema::PrimitiveType_SparseSoftmaxCrossEntropy || | kernel->Type() == schema::PrimitiveType_SparseSoftmaxCrossEntropy || | ||||
| @@ -42,7 +42,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>; | using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>; | ||||
| class TrainSession : virtual public session::TrainSession, virtual public lite::LiteSession { | class TrainSession : virtual public session::TrainSession, virtual public lite::LiteSession { | ||||
| public: | public: | ||||
| @@ -59,6 +58,8 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: | |||||
| int Train() override; | int Train() override; | ||||
| int Eval() override; | int Eval() override; | ||||
| int SetLearningRate(float learning_rate) override; | |||||
| float GetLearningRate() override; | |||||
| void BindThread(bool if_bind) override { return lite::LiteSession::BindThread(if_bind); } | void BindThread(bool if_bind) override { return lite::LiteSession::BindThread(if_bind); } | ||||
| std::vector<tensor::MSTensor *> GetInputs() const override { return lite::LiteSession::GetInputs(); } | std::vector<tensor::MSTensor *> GetInputs() const override { return lite::LiteSession::GetInputs(); } | ||||
| @@ -80,6 +81,10 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: | |||||
| return lite::RET_ERROR; | return lite::RET_ERROR; | ||||
| } | } | ||||
| std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetPredictions() const override { | |||||
| return eval_output_tensor_map_; | |||||
| } | |||||
| protected: | protected: | ||||
| void AllocWorkSpace(); | void AllocWorkSpace(); | ||||
| bool IsLossKernel(const kernel::LiteKernel *kernel) const; | bool IsLossKernel(const kernel::LiteKernel *kernel) const; | ||||
| @@ -1,11 +1,13 @@ | |||||
| mini_alexnet | mini_alexnet | ||||
| #mobilenetv1 | |||||
| # mobilenetv1 | |||||
| mobilenetv2 | mobilenetv2 | ||||
| mobilenetv3 | mobilenetv3 | ||||
| lenet | lenet | ||||
| effnet | effnet | ||||
| effnet_tune | |||||
| # effnet_tune | |||||
| # lenetv1 | # lenetv1 | ||||
| # resnet | # resnet | ||||
| # effnetv1 | |||||
| # googlenet | |||||
| # densenet | |||||
| # one_net | |||||
| #LAST | #LAST | ||||
| @@ -83,7 +83,7 @@ function Run_x86() { | |||||
| --inDataFile=${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin \ | --inDataFile=${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin \ | ||||
| --expectedDataFile=${train_io_path}/${model_name}_outputs.bin \ | --expectedDataFile=${train_io_path}/${model_name}_outputs.bin \ | ||||
| --exportFile=${ms_models_path}/${model_name}_train_exported.ms >> "${run_x86_log_file}" \ | --exportFile=${ms_models_path}/${model_name}_train_exported.ms >> "${run_x86_log_file}" \ | ||||
| --epochs=${epoch_num} | |||||
| --epochs=${epoch_num} --numThreads=${threads} | |||||
| if [ $? = 0 ]; then | if [ $? = 0 ]; then | ||||
| run_result='x86: '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file} | run_result='x86: '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file} | ||||
| else | else | ||||
| @@ -178,7 +178,8 @@ function Run_arm() { | |||||
| --modelFile=${model_name}_train.ms \ | --modelFile=${model_name}_train.ms \ | ||||
| --inDataFile=${tmp_dir}/${model_name}_input1.bin,${tmp_dir}/${model_name}_input2.bin \ | --inDataFile=${tmp_dir}/${model_name}_input1.bin,${tmp_dir}/${model_name}_input2.bin \ | ||||
| --expectedDataFile=${tmp_dir}/${model_name}_outputs.bin \ | --expectedDataFile=${tmp_dir}/${model_name}_outputs.bin \ | ||||
| --exportFile=${tmp_dir}/${model_name}_train_exported.ms | |||||
| --exportFile=${tmp_dir}/${model_name}_train_exported.ms \ | |||||
| --numThreads=${threads} | |||||
| ENDM | ENDM | ||||
| ) | ) | ||||
| echo "${adb_cmd}" >> ${run_arm_log_file} | echo "${adb_cmd}" >> ${run_arm_log_file} | ||||
| @@ -221,8 +222,9 @@ echo ${basepath} | |||||
| # Example:run_benchmark_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408" | # Example:run_benchmark_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408" | ||||
| # For running on arm64, use -t to set platform tools path (for using adb commands) | # For running on arm64, use -t to set platform tools path (for using adb commands) | ||||
| epoch_num=1 | epoch_num=1 | ||||
| threads=1 | |||||
| train_io_path="" | train_io_path="" | ||||
| while getopts "r:m:d:i:e:vt:" opt; do | |||||
| while getopts "r:m:d:i:e:vt:q:" opt; do | |||||
| case ${opt} in | case ${opt} in | ||||
| r) | r) | ||||
| release_path=${OPTARG} | release_path=${OPTARG} | ||||
| @@ -249,9 +251,13 @@ while getopts "r:m:d:i:e:vt:" opt; do | |||||
| run_valgrind="valgrind --log-file=valgrind.log " | run_valgrind="valgrind --log-file=valgrind.log " | ||||
| echo "Run x86 with valgrind" | echo "Run x86 with valgrind" | ||||
| ;; | ;; | ||||
| q) | |||||
| threads=${OPTARG} | |||||
| echo "threads=${threads}" | |||||
| ;; | |||||
| t) | t) | ||||
| epoch_num=${OPTARG} | epoch_num=${OPTARG} | ||||
| echo "train epoch num is ${OPTARG}" | |||||
| echo "train epoch num is ${epoch_num}" | |||||
| ;; | ;; | ||||
| ?) | ?) | ||||
| echo "unknown para" | echo "unknown para" | ||||
| @@ -511,6 +511,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano | |||||
| auto valueNode = input_anode->cast<ValueNodePtr>(); | auto valueNode = input_anode->cast<ValueNodePtr>(); | ||||
| auto paramTensor = std::make_unique<schema::TensorT>(); | auto paramTensor = std::make_unique<schema::TensorT>(); | ||||
| auto value = valueNode->value(); | auto value = valueNode->value(); | ||||
| #ifdef SUPPORT_TRAIN | |||||
| paramTensor->name = valueNode->fullname_with_scope(); | |||||
| #endif | |||||
| if (value->isa<tensor::Tensor>()) { | if (value->isa<tensor::Tensor>()) { | ||||
| auto valueAbstract = valueNode->abstract(); | auto valueAbstract = valueNode->abstract(); | ||||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract); | auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract); | ||||
| @@ -527,7 +530,6 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano | |||||
| paramTensor->dims = dims; | paramTensor->dims = dims; | ||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| if (paramTensor->dims.size() == 0) paramTensor->dims = {1}; | if (paramTensor->dims.size() == 0) paramTensor->dims = {1}; | ||||
| paramTensor->name = valueNode->fullname_with_scope(); | |||||
| #endif | #endif | ||||
| paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; | paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; | ||||
| auto data = value->cast<tensor::TensorPtr>(); | auto data = value->cast<tensor::TensorPtr>(); | ||||
| @@ -135,6 +135,7 @@ int NetTrain::ReadCalibData() { | |||||
| MS_LOG(INFO) << "Start reading calibData file"; | MS_LOG(INFO) << "Start reading calibData file"; | ||||
| std::string tensor_name; | std::string tensor_name; | ||||
| while (!in_file.eof()) { | while (!in_file.eof()) { | ||||
| getline(in_file, line); | getline(in_file, line); | ||||
| std::stringstream string_line1(line); | std::stringstream string_line1(line); | ||||
| @@ -189,7 +190,6 @@ int NetTrain::CompareOutput() { | |||||
| MS_ASSERT(tensor->MutableData() != nullptr); | MS_ASSERT(tensor->MutableData() != nullptr); | ||||
| auto outputs = tensor->MutableData(); | auto outputs = tensor->MutableData(); | ||||
| float bias = CompareData<float>(node_or_tensor_name, tensor->shape(), reinterpret_cast<float *>(outputs)); | float bias = CompareData<float>(node_or_tensor_name, tensor->shape(), reinterpret_cast<float *>(outputs)); | ||||
| if (bias >= 0) { | if (bias >= 0) { | ||||
| total_bias += bias; | total_bias += bias; | ||||
| total_size++; | total_size++; | ||||
| @@ -228,7 +228,7 @@ int NetTrain::CompareOutput() { | |||||
| int NetTrain::MarkPerformance() { | int NetTrain::MarkPerformance() { | ||||
| MS_LOG(INFO) << "Running train loops..."; | MS_LOG(INFO) << "Running train loops..."; | ||||
| std::cout << "Running train loops..." << std::endl; | std::cout << "Running train loops..." << std::endl; | ||||
| uint64_t time_min = 1000000; | |||||
| uint64_t time_min = 0xFFFFFFFFFFFFFFFF; | |||||
| uint64_t time_max = 0; | uint64_t time_max = 0; | ||||
| uint64_t time_avg = 0; | uint64_t time_avg = 0; | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_TOOLS_NET_TRAIN_NET_TRAIN_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_NET_TRAIN_NET_TRAIN_H_ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_ | |||||
| #include <getopt.h> | #include <getopt.h> | ||||
| #include <signal.h> | #include <signal.h> | ||||
| @@ -59,6 +59,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { | |||||
| AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0); | AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0); | ||||
| AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false); | AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false); | ||||
| AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1); | AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1); | ||||
| AddFlag(&NetTrainFlags::num_threads_, "numThreads", "Run threads number", 1); | |||||
| // MarkAccuracy | // MarkAccuracy | ||||
| AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", ""); | AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", ""); | ||||
| AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); | AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); | ||||
| @@ -239,4 +240,4 @@ class MS_API NetTrain { | |||||
| int MS_API RunNetTrain(int argc, const char **argv); | int MS_API RunNetTrain(int argc, const char **argv); | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| #endif // MINDSPORE_LITE_TOOLS_NET_TRAIN_NET_TRAIN_H_ | |||||
| #endif // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_ | |||||
| @@ -136,10 +136,10 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT | |||||
| static const std::vector<schema::PrimitiveType> needInsertOpList = { | static const std::vector<schema::PrimitiveType> needInsertOpList = { | ||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, | |||||
| schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Split, | |||||
| schema::PrimitiveType_Slice, schema::PrimitiveType_Crop, schema::PrimitiveType_Mul, | |||||
| schema::PrimitiveType_Add | |||||
| schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, | |||||
| schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Split, | |||||
| schema::PrimitiveType_Crop, schema::PrimitiveType_Mul, schema::PrimitiveType_Add, | |||||
| schema::PrimitiveType_ActivationGrad | |||||
| #else | #else | ||||
| schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, | schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, | ||||
| schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add, | schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add, | ||||