@@ -16,7 +16,7 @@ CFLAGS := -Ofast -std=c++17 \
 -I . \
 -I ./msl/train \
 -I ./msl/train/minddata \
--I ./msl/train/third_party/flatbuffers/include
+-I ./msl/tools/third_party/flatbuffers/include
 ifeq ($(TARGET),arm64)
@@ -79,15 +79,17 @@ cp model/*.ms ${PACKAGE}/model || exit 1
 cp scripts/*.sh ${PACKAGE}/
 # Copy the shared MindSpore ToD library
 tar -xzf ${TARBALL}
 mv mindspore-*/train/lib ${PACKAGE}/
 mv mindspore-*/train/minddata/lib/* ${PACKAGE}/lib/
 mv mindspore-*/train/minddata/third_party/libjpeg-turbo/lib/* ${PACKAGE}/lib/
+if [ "${TARGET}" == "arm64" ]; then
+tar -xzf ${TARBALL} --wildcards --no-anchored hiai_ddk
+mv mindspore-*/train/third_party/hiai_ddk/lib/* ${PACKAGE}/lib/
+fi
 rm -rf msl
-mkdir msl
-mv mindspore-*/* msl/
-rm -rf mindspore-*
+mv mindspore-* msl/
 # Copy the dataset to the package
 cp -r $MNIST_DATA_PATH ${PACKAGE}/dataset || exit 1
@@ -101,7 +101,7 @@ void NetRunner::InitAndFigureInputs() {
 session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context);
 MS_ASSERT(nullptr != session_);
-loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_, &context);
+loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_);
 acc_metrics_ = std::shared_ptr<AccuracyMetrics>(new AccuracyMetrics);
@@ -23,7 +23,6 @@
 #include <vector>
 #include <memory>
 #include <string>
-#include "include/train_session.h"
 #include "include/train/train_loop.h"
 #include "include/train/accuracy_metrics.h"
 #include "include/ms_tensor.h"
@@ -1,32 +1,39 @@
 BASE_DIR=$(realpath ../../../../)
 APP:=bin/net_runner
 MSLIB:=mindspore-lite
+LMDLIB:=-lminddata-lite
 MSDIR:=$(realpath package-$(TARGET)/lib)
+ifneq ("$(wildcard $(MSDIR)/libhiai.so)","")
+LHIAILIB:=-lhiai_ir_build -lhiai_ir -lhiai
+else
+LHIAILIB:=
+endif
 SRC:=src/net_runner.cc src/dataset.cc
 OBJ:=$(SRC:.cc=.o)
 CFLAGS := -Ofast -std=c++17 \
 -I . \
 -I ./msl/train \
--I ./msl/train/third_party/flatbuffers/include
+-I ./msl/train/minddata \
+-I ./msl/tools/third_party/flatbuffers/include
 ifeq ($(TARGET),arm64)
 CXX := ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin/clang++
 CFLAGS += --target=aarch64-none-linux-android21 --gcc-toolchain=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64 --sysroot=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot -fdata-sections -ffunction-sections
 LDFLAGS := --target=aarch64-none-linux-android21 --gcc-toolchain=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64 --sysroot=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot -Wl,--gc-sections
-LDFLAGS += -L$(MSDIR) -l$(MSLIB) -pthread -llog -latomic -lm -Wl,-rpath,$(MSDIR)
+LDFLAGS += -L$(MSDIR) -l$(MSLIB) $(LMDLIB) $(LHIAILIB) -pthread -llog -latomic -lm -Wl,-rpath,$(MSDIR)
 else
 CFLAGS += -g
-LDFLAGS := -L$(MSDIR) -l$(MSLIB) -lpthread -Wl,-rpath,$(MSDIR)
+LDFLAGS := -L$(MSDIR) -l$(MSLIB) $(LMDLIB) $(LHIAILIB) -lpthread -Wl,-rpath,$(MSDIR)
 endif
 LD := ${CXX}
 all:$(APP)
-$(APP): $(OBJ) $(MSDIR)/lib$(MSLIB).so
+$(APP): $(OBJ)
 @mkdir -p bin
 $(LD) $(OBJ) $(LDFLAGS) -o $@
@@ -8,6 +8,7 @@ fi
 echo "============Exporting=========="
 if [ -n "$1" ]; then
 DOCKER_IMG=$1
+rm *.so*
 docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "python transfer_learning_export.py; chmod 444 transfer_learning_tod*.mindir; rm -rf __pycache__"
 else
 echo "MindSpore docker was not provided, attempting to run locally"
@@ -49,7 +49,7 @@ HEAD.weight.set_data(M.Tensor(np.random.normal(
 0, 0.1, HEAD.weight.data.shape).astype("float32")))
 HEAD.bias.set_data(M.Tensor(np.zeros(HEAD.bias.data.shape, dtype="float32")))
-sgd = M.nn.SGD(HEAD.trainable_params(), learning_rate=0.01, momentum=0.9,
+sgd = M.nn.SGD(HEAD.trainable_params(), learning_rate=0.015, momentum=0.9,
 dampening=0.01, weight_decay=0.0, nesterov=False, loss_scale=1.0)
 net = TrainWrap(HEAD, optimizer=sgd)
 backbone_out = M.Tensor(np.zeros([BATCH_SIZE, 1000]).astype(np.float32))
@@ -82,10 +82,13 @@ tar -xzf ${TARBALL}
 mv mindspore-*/train/lib ${PACKAGE}/
 mv mindspore-*/train/minddata/lib/* ${PACKAGE}/lib/
 mv mindspore-*/train/minddata/third_party/libjpeg-turbo/lib/* ${PACKAGE}/lib/
+if [ "${TARGET}" == "arm64" ]; then
+tar -xzf ${TARBALL} --wildcards --no-anchored hiai_ddk
+mv mindspore-*/train/third_party/hiai_ddk/lib/* ${PACKAGE}/lib/
+fi
 rm -rf msl
-mkdir msl
-mv mindspore-*/* msl/
-rm -rf mindspore-*
+mv mindspore-* msl/
 # Convert the dataset into the package
 ./prepare_dataset.sh ${PLACES_DATA_PATH} || exit 1
@@ -22,7 +22,7 @@
 #include <map>
 #include <vector>
 #include <string>
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "include/ms_tensor.h"
 #include "src/dataset.h"
@@ -24,7 +24,7 @@ namespace mindspore {
 namespace lite {
 constexpr int METRICS_CLASSIFICATION = 0;
-constexpr int METRICS_MULTILABLE = 1;
+constexpr int METRICS_MULTILABEL = 1;
 class AccuracyMetrics : public Metrics {
 public:
@@ -22,7 +22,7 @@
 #include <unordered_map>
 #include "include/train/train_loop_callback.h"
 #include "include/train/metrics.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 namespace mindspore {
 class MSTensor;
@@ -41,10 +41,9 @@ class TrainLoop {
 /// \brief Static method to create a TrainLoop object
 ///
 /// \param[in] train_session Train session object as return from CreateSession\CreateTransferSession API
-/// \param[in] context Defines the context of the session to be created
 ///
 /// \return Pointer of MindSpore Lite TrainLoop
-static TrainLoop *CreateTrainLoop(session::TrainSession *train_session, lite::Context *context, int batch_size = -1);
+static TrainLoop *CreateTrainLoop(session::TrainSession *train_session);
 /// \brief Class destructor
 virtual ~TrainLoop() = default;
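A minimal usage sketch of the updated CreateTrainLoop signature follows; the surrounding setup (context fields, file name, cleanup) is illustrative only and not part of this patch.

// Hypothetical caller, mirroring net_runner.cc: the loop is now created from the
// session alone; the context is only needed to create the TrainSession itself.
#include <string>
#include "include/context.h"
#include "include/train/train_session.h"
#include "include/train/train_loop.h"

int RunLoop(const std::string &ms_file) {
  mindspore::lite::Context context;  // default CPU context; tune it as needed
  auto *session = mindspore::session::TrainSession::CreateSession(ms_file, &context);
  if (session == nullptr) return -1;
  // Old call: CreateTrainLoop(session, &context, batch_size);
  auto *loop = mindspore::session::TrainLoop::CreateTrainLoop(session);
  if (loop == nullptr) return -1;
  // ... attach metrics/callbacks and run training as net_runner.cc does ...
  delete loop;
  return 0;
}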
@@ -17,7 +17,7 @@
 #include <jni.h>
 #include "common/ms_log.h"
 #include "common/jni_utils.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "include/errorcode.h"
 extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createSession(JNIEnv *env, jobject thiz,
@@ -55,32 +55,33 @@ int ActivationGradCPUKernel::DoActivation(int task_id) {
 size_t start = stride * task_id;
 auto error_code = RET_OK;
-if (param_act_grad_->type_ == schema::ActivationType_RELU) {
-error_code = ReluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
-} else if (param_act_grad_->type_ == schema::ActivationType_RELU6) {
-error_code = Relu6Grad(yt_addr + start, input_addr + start, count, output_addr + start);
-} else if (param_act_grad_->type_ == schema::ActivationType_LEAKY_RELU) {
-error_code = LReluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_);
-} else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) {
-// Sigmoid gets the input tensors in reverse order!
-error_code = SigmoidGrad(input_addr + start, yt_addr + start, count, output_addr + start);
-} else if (param_act_grad_->type_ == schema::ActivationType_TANH) {
-error_code = TanhGrad(input_addr + start, yt_addr + start, count, output_addr + start);
-} else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) {
-error_code = HSwishGrad(yt_addr + start, input_addr + start, count, output_addr + start);
-} else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) {
-error_code = HSigmoidGrad(yt_addr + start, input_addr + start, count, output_addr + start);
-} else if (param_act_grad_->type_ == schema::ActivationType_ELU) {
-error_code = EluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_);
-} else if (param_act_grad_->type_ == schema::ActivationType_GELU) {
-error_code = GeluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
-} else {
-MS_LOG(ERROR) << "Activation type error";
-return RET_ERROR;
-}
-if (error_code != RET_OK) {
-return RET_ERROR;
+if (count > 0) {
+if (param_act_grad_->type_ == schema::ActivationType_RELU) {
+error_code = ReluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+} else if (param_act_grad_->type_ == schema::ActivationType_RELU6) {
+error_code = Relu6Grad(yt_addr + start, input_addr + start, count, output_addr + start);
+} else if (param_act_grad_->type_ == schema::ActivationType_LEAKY_RELU) {
+error_code = LReluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_);
+} else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) {
+// Sigmoid gets the input tensors in reverse order!
+error_code = SigmoidGrad(input_addr + start, yt_addr + start, count, output_addr + start);
+} else if (param_act_grad_->type_ == schema::ActivationType_TANH) {
+error_code = TanhGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+} else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) {
+error_code = HSwishGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+} else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) {
+error_code = HSigmoidGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+} else if (param_act_grad_->type_ == schema::ActivationType_ELU) {
+error_code = EluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_);
+} else if (param_act_grad_->type_ == schema::ActivationType_GELU) {
+error_code = GeluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
+} else {
+MS_LOG(ERROR) << "Activation type error";
+return RET_ERROR;
+}
+if (error_code != RET_OK) {
+return RET_ERROR;
+}
 }
 return RET_OK;
 }
@@ -32,8 +32,8 @@ namespace mindspore::kernel {
 int AdamCPUKernel::ReSize() { return RET_OK; }
-int DoAdam(float *m, float *v, float *gradient, float *weight, float beta1, float beta2, float beta1_power,
-float beta2_power, float eps, float learning_rate, bool nesterov, size_t start, size_t end) {
+static int DoAdam(float *m, float *v, float *gradient, float *weight, float beta1, float beta2, float beta1_power,
+float beta2_power, float eps, float learning_rate, bool nesterov, int start, int end) {
 if ((1.f - beta1_power) <= 0.0f) {
 MS_LOG(ERROR) << "divisor cannot be 0 or below";
 return RET_ERROR;
@@ -47,13 +47,13 @@ int DoAdam(float *m, float *v, float *gradient, float *weight, float beta1, floa
 const float one_minus_beta1 = 1.f - beta1;
 const float one_minus_beta2 = 1.f - beta2;
 if (nesterov) { // Nadam
-for (size_t i = start; i < end; ++i) {
+for (int i = start; i < end; ++i) {
 m[i] += (gradient[i] - m[i]) * one_minus_beta1;
 v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2;
 weight[i] -= update_lr * (m[i] * beta1 + one_minus_beta1 * gradient[i]) / (std::sqrt(v[i]) + eps);
 }
 } else {
-for (size_t i = start; i < end; ++i) {
+for (int i = start; i < end; ++i) {
 m[i] += (gradient[i] - m[i]) * one_minus_beta1;
 v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2;
 weight[i] -= update_lr * m[i] / (std::sqrt(v[i]) + eps);
@@ -77,7 +77,6 @@ int AdamCPUKernel::Execute(int task_id) {
 int stride = UP_DIV(length, thread_count_);
 int count = MSMIN(stride, length - stride * task_id);
 int start = stride * task_id;
 int end = start + count;
@@ -30,15 +30,15 @@ using mindspore::schema::PrimitiveType_ApplyMomentum;
 namespace mindspore::kernel {
 int ApplyMomentumCPUKernel::ReSize() { return RET_OK; }
-int DoApplyMomentum(float *weight, float *accumulate, float learning_rate, float *gradient, float moment, bool nesterov,
-size_t start, size_t end) {
+static int DoApplyMomentum(float *weight, float *accumulate, float learning_rate, float *gradient, float moment,
+bool nesterov, int start, int end) {
 if (nesterov) {
-for (size_t i = start; i < end; i++) {
+for (int i = start; i < end; i++) {
 accumulate[i] = accumulate[i] * moment + gradient[i];
 weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
 }
 } else {
-for (size_t i = start; i < end; i++) {
+for (int i = start; i < end; i++) {
 accumulate[i] = accumulate[i] * moment + gradient[i];
 weight[i] -= accumulate[i] * learning_rate;
 }
@@ -56,6 +56,7 @@ int ApplyMomentumCPUKernel::Execute(int task_id) {
 int stride = UP_DIV(length, thread_count_);
 int count = MSMIN(stride, length - stride * task_id);
+count = (count < 0) ? 0 : count;
 int start = stride * task_id;
 int end = start + count;
@@ -72,7 +72,9 @@ int ArithmeticSelfGradCPUKernel::DoArithmeticSelfGrad(int task_id) {
 int count = MSMIN(stride, length - stride * task_id);
 int start = stride * task_id;
-(*self_grad_operation_)(dy + start, in_x + start, dx + start, count);
+if (count > 0) {
+(*self_grad_operation_)(dy + start, in_x + start, dx + start, count);
+}
 return RET_OK;
 }
@@ -41,7 +41,9 @@ int AssignCPUKernel::Execute(int task_id) {
 int start = stride * task_id;
-memcpy(&(x[start]), &(y[start]), count * sizeof(float));
+if (count > 0) {
+memcpy(&(x[start]), &(y[start]), count * sizeof(float));
+}
 return RET_OK;
 }
@@ -76,6 +76,7 @@ int BNGradCPUKernel::Execute(int task_id) {
 int total = spatial * batch;
 int stride = UP_DIV(total, thread_num);
 int count = MSMIN(stride, total - stride * task_id);
+count = (count < 0) ? 0 : count;
 switch (stage) {
 case 0: {
 for (int job = task_id; job < 4; job += thread_num) {
@@ -108,6 +108,7 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
 float *mat_tmp = mat_workspace + mat_alloc_;
 int stride = UP_DIV(batch, thread_num);
 int count = MSMIN(stride, batch - stride * task_id);
+count = (count < 0) ? 0 : count;
 int start = stride * task_id;
 int end = start + count;
@@ -115,6 +116,7 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) {
 #ifdef ENABLE_ARM
 stride = UP_DIV(k_h * k_w, thread_num);
 count = MSMIN(stride, k_h * k_w - stride * task_id);
+count = (count < 0) ? 0 : count;
 start = stride * task_id;
 ConvDwFilterGrad(x_addr, dy_addr, dw_addr, start, count, conv_param);
 #else
@@ -92,6 +92,7 @@ int ConvolutionGradInputCPUKernel::Execute(int task_id) {
 float *mat_workspace = workspace_temp + ws_size_;
 int stride = UP_DIV(batch, thread_num);
 int count = MSMIN(stride, batch - stride * task_id);
+count = (count < 0) ? 0 : count;
 int start = stride * task_id;
 int end = start + count;
@@ -67,22 +67,24 @@ int DropoutCPUKernel::Execute(int task_id) {
 int stride = UP_DIV(length, thread_count_);
 int count = MSMIN(stride, length - stride * task_id);
-size_t start = stride * task_id;
-size_t end = start + count;
+int start = stride * task_id;
+int end = start + count;
 if (param == nullptr) {
 MS_LOG(ERROR) << "Dropout op_parameter_ nullptr";
 return RET_NULL_PTR;
 }
-if (IsEval()) {
-std::copy(&(input_ptr[start]), &(input_ptr[end]), &(output_ptr[start]));
-} else {
-std::default_random_engine generator;
-std::bernoulli_distribution distribution(param->ratio_);
-for (size_t i = start; i < end; i++) {
-mask[i] = distribution(generator);
-output_ptr[i] = input_ptr[i] * mask[i] * scale_;
+if (count > 0) {
+if (IsEval()) {
+std::copy(&(input_ptr[start]), &(input_ptr[end]), &(output_ptr[start]));
+} else {
+std::default_random_engine generator;
+std::bernoulli_distribution distribution(param->ratio_);
+for (int i = start; i < end; i++) {
+mask[i] = distribution(generator);
+output_ptr[i] = input_ptr[i] * mask[i] * scale_;
+}
 }
 }
 return RET_OK;
@@ -46,7 +46,7 @@ int NegGradCPUKernel::DoNegGrad(int task_id) {
 int stride = UP_DIV(length, thread_count_);
 int count = MSMIN(stride, length - stride * task_id);
+count = (count < 0) ? 0 : count;
 int start = stride * task_id;
 ElementNegative(dy + start, dx + start, count);
@@ -50,7 +50,7 @@ int PowerGradCPUKernel::Execute(int task_id) {
 int stride = UP_DIV(length, thread_count_);
 int count = MSMIN(stride, length - stride * task_id);
+count = (count < 0) ? 0 : count;
 int start = stride * task_id;
 int end = start + count;
@@ -33,21 +33,21 @@ namespace mindspore::kernel {
 int SgdCPUKernel::ReSize() { return RET_OK; }
 int DoSgd(float *weight, float *accumulate, float *gradient, float learning_rate, float dampening, float moment,
-bool nesterov, size_t start, size_t end) {
+bool nesterov, int start, int end) {
 if (moment > 0.f) {
 if (nesterov) {
-for (size_t i = start; i < end; ++i) {
+for (int i = start; i < end; ++i) {
 accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - dampening);
 weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
 }
 } else {
-for (size_t i = start; i < end; ++i) {
+for (int i = start; i < end; ++i) {
 accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - dampening);
 weight[i] -= accumulate[i] * learning_rate;
 }
 }
 } else {
-for (size_t i = start; i < end; ++i) {
+for (int i = start; i < end; ++i) {
 weight[i] -= gradient[i] * learning_rate;
 }
 }
@@ -55,14 +55,14 @@ int DoSgd(float *weight, float *accumulate, float *gradient, float learning_rate
 }
 int DoSgdInit(float *weight, float *accumulate, float *gradient, float *stat, float learning_rate, float dampening,
-float moment, bool nesterov, size_t start, size_t end) {
+float moment, bool nesterov, int start, int end) {
 std::copy(&(gradient[start]), &(gradient[end]), &(accumulate[start]));
 if (nesterov) {
-for (size_t i = start; i < end; ++i) {
+for (int i = start; i < end; ++i) {
 weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate;
 }
 } else {
-for (size_t i = start; i < end; ++i) {
+for (int i = start; i < end; ++i) {
 weight[i] -= accumulate[i] * learning_rate;
 }
 }
@@ -80,7 +80,7 @@ int SgdCPUKernel::Execute(int task_id) {
 int stride = UP_DIV(length, thread_count_);
 int count = MSMIN(stride, length - stride * task_id);
+count = (count < 0) ? 0 : count;
 int start = stride * task_id;
 int end = start + count;
@@ -97,16 +97,18 @@ int SgdCPUKernel::ExecuteInit(int task_id) {
 auto gradient = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
 float moment = reinterpret_cast<float *>(in_tensors_.at(4)->MutableData())[0];
 auto stat = reinterpret_cast<float *>(in_tensors_.at(5)->MutableData());
-size_t length = in_tensors_.at(0)->ElementsNum();
+int length = in_tensors_.at(0)->ElementsNum();
-size_t stride = UP_DIV(length, thread_count_);
-size_t count = MSMIN(stride, length - stride * task_id);
+int stride = UP_DIV(length, thread_count_);
+int count = MSMIN(stride, length - stride * task_id);
-size_t start = stride * task_id;
-size_t end = start + count;
+int start = stride * task_id;
+int end = start + count;
-DoSgdInit(weight, accumulate, gradient, stat, learning_rate, sgd_param_->dampening_, moment,
-sgd_param_->use_nesterov_, start, end);
+if (count > 0) {
+DoSgdInit(weight, accumulate, gradient, stat, learning_rate, sgd_param_->dampening_, moment,
+sgd_param_->use_nesterov_, start, end);
+}
 return RET_OK;
 }
@@ -40,7 +40,7 @@ int SmoothL1LossGradCPUKernel::Execute(int task_id) {
 int stride = UP_DIV(length, thread_count_);
 int count = MSMIN(stride, length - stride * task_id);
+count = (count < 0) ? 0 : count;
 int start = stride * task_id;
 int end = start + count;
@@ -19,7 +19,7 @@
 #include <vector>
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "src/common/utils.h"
 #include "src/train/train_utils.h"
@@ -20,7 +20,7 @@
 #include <utility>
 #include <vector>
 #include <iostream>
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "src/common/utils.h"
 #include "src/tensor.h"
@@ -23,7 +23,7 @@
 #include <fstream>
 #include <memory>
 #include "include/errorcode.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "src/common/utils.h"
 #include "src/tensor.h"
@@ -20,7 +20,7 @@
 #include <memory>
 #include <algorithm>
 #include "include/errorcode.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "include/iterator.h"
 #include "src/common/log_adapter.h"
@@ -168,8 +168,7 @@ int TrainLoop::LoadPartialData(std::vector<tensor::MSTensor *> inputs, dataset::
 } // namespace lite
-session::TrainLoop *session::TrainLoop::CreateTrainLoop(session::TrainSession *train_session, lite::Context *context,
-int batch_size) {
+session::TrainLoop *session::TrainLoop::CreateTrainLoop(session::TrainSession *train_session) {
 auto loop = new (std::nothrow) lite::TrainLoop(train_session);
 return loop;
 }
@@ -23,7 +23,6 @@
 #include "include/errorcode.h"
 #include "include/train/train_loop.h"
 #include "include/train/metrics.h"
-#include "include/train_session.h"
 #include "include/datasets.h"
 #include "include/iterator.h"
 #include "src/common/log_adapter.h"
@@ -20,7 +20,7 @@
 #include <tuple>
 #include <unordered_map>
 #include <memory>
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "src/train/train_model.h"
 #include "src/lite_session.h"
@@ -20,7 +20,6 @@
 #include <tuple>
 #include <unordered_map>
 #include <utility>
-#include "include/train_session.h"
 #include "src/train/train_model.h"
 #include "src/lite_session.h"
 #include "src/train/train_session.h"
@@ -1,5 +1,5 @@
 mini_alexnet
-#nin
+nin
 lenet
 mobilenetv1
 mobilenetv2
@@ -10,5 +10,5 @@ effnet_tune
 googlenet
 densenet
 shufflenetv2
-#xception
+# xception
 # LAST
@@ -24,7 +24,7 @@
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 #include "include/context.h"
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
@@ -33,7 +33,7 @@
 #include "tools/common/flag_parser.h"
 #include "src/common/file_utils.h"
 #include "src/common/utils.h"
-#include "include/train_session.h"
+#include "include/train/train_session.h"
 namespace mindspore::lite {
 enum MS_API DataType { kImage = 0, kBinary = 1 };
@@ -156,7 +156,6 @@ class MS_API NetTrain {
 std::cout << refOutput[j] << " ";
 }
 for (int j = 0; j < size; j++) {
-std::cout << std::endl;
 if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
 std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
 MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
@@ -126,6 +126,7 @@ constexpr auto kNameArgMinWithValue = "ArgMinWithValue";
 constexpr auto kNameBatchMatMul = "BatchMatMul";
 constexpr auto kNameFusedBatchNormEx = "FusedBatchNormEx";
 constexpr auto kNameFusedBatchNormGradEx = "FusedBatchNormGradEx";
+constexpr auto kNameFusedBatchNormGradCPU = "FusedBatchNormGradCPU";
 constexpr auto kNameHSigmoid = "HSigmoid";
 constexpr auto kNameHSigmoidGrad = "HSigmoidGrad";
 constexpr auto kNameHSwish = "HSwish";
@@ -549,6 +550,7 @@ REGIST_PRIMITIVE_ADJUST(kNameEluGrad, MoveAttrMapActivationGrad)
 REGIST_PRIMITIVE_ADJUST(kNameExp, MoveAttrMapCommon<ops::ExpFusion>)
 REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormEx, MoveAttrMapCommon<ops::FusedBatchNorm>)
 REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapCommon<ops::BatchNormGrad>)
+REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradCPU, MoveAttrMapCommon<ops::BatchNormGrad>)
 REGIST_PRIMITIVE_ADJUST(kNameGeLU, MoveAttrMapActivation)
 REGIST_PRIMITIVE_ADJUST(kNameGeLUGrad, MoveAttrMapActivationGrad)
 REGIST_PRIMITIVE_ADJUST(kNameHSigmoid, MoveAttrMapActivation)