
fix code example scripts and README

tags/v1.1.0
yoni, 5 years ago · commit 0d0d5853b5
11 changed files with 29 additions and 23 deletions
  1. mindspore/lite/examples/train_lenet/README.md (+1, -1)
  2. mindspore/lite/examples/train_lenet/model/lenet_export.py (+1, -1)
  3. mindspore/lite/examples/train_lenet/prepare_and_run.sh (+4, -3)
  4. mindspore/lite/examples/transfer_learning/README.md (+3, -3)
  5. mindspore/lite/examples/transfer_learning/prepare_and_run.sh (+3, -2)
  6. mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc (+1, -1)
  7. mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc (+4, -0)
  8. mindspore/lite/src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits.cc (+3, -3)
  9. mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc (+4, -4)
  10. mindspore/lite/src/train/train_session.h (+1, -1)
  11. mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc (+4, -4)

mindspore/lite/examples/train_lenet/README.md (+1, -1)

@@ -65,7 +65,7 @@ where:
 - DATASET_PATH is the path to the [dataset](#dataset),
 - MINDSPORE_DOCKER is the image name of the docker that runs [MindSpore](#environment-requirements). If not provided MindSpore will be run locally
 - REALEASE.tar.gz is a pointer to the MindSpore ToD release tar ball. If not provided, the script will attempt to find MindSpore ToD compilation output
-- target is defaulted to arm64, i.e., on-device. If x86 is provided, the demo will be run locally. Note that infrastructure is not optimized for device
+- target is defaulted to arm64, i.e., on-device. If x86 is provided, the demo will be run locally. Note that infrastructure is not optimized for running on x86. Also, note that the user needs to call "make clean" when switching between targets.

 # Script Detailed Description



mindspore/lite/examples/train_lenet/model/lenet_export.py (+1, -1)

@@ -24,7 +24,7 @@ from train_utils import TrainWrap

 n = LeNet5()
 n.set_train()
-context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
+context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs=False)

 BATCH_SIZE = 32
 x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)


mindspore/lite/examples/train_lenet/prepare_and_run.sh (+4, -3)

@@ -12,7 +12,6 @@ checkopts()
 MNIST_DATA_PATH=""
 while getopts 'D:d:r:t:' opt
 do
-  OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
   case "${opt}" in
     D)
       MNIST_DATA_PATH=$OPTARG
@@ -70,7 +69,7 @@ PACKAGE=package-${TARGET}

 rm -rf ${PACKAGE}
 mkdir -p ${PACKAGE}/model
-cp model/*.ms ${PACKAGE}/model
+cp model/*.ms ${PACKAGE}/model || exit 1

 # Copy the running script to the package
 cp scripts/*.sh ${PACKAGE}/
@@ -85,7 +84,7 @@ mv mindspore-*/* msl/
 rm -rf mindspore-*

 # Copy the dataset to the package
-cp -r ${MNIST_DATA_PATH} ${PACKAGE}/dataset
+cp -r $MNIST_DATA_PATH ${PACKAGE}/dataset || exit 1

 echo "==========Compiling============"
 make TARGET=${TARGET}
@@ -94,6 +93,8 @@ make TARGET=${TARGET}
 mv bin ${PACKAGE}/ || exit 1

 if [ "${TARGET}" == "arm64" ]; then
+  cp ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android/libc++_shared.so ${PACKAGE}/lib/ || exit 1
+
 echo "=======Pushing to device======="
 adb push ${PACKAGE} /data/local/tmp/



mindspore/lite/examples/transfer_learning/README.md (+3, -3)

@@ -29,7 +29,7 @@ For this demo we will use only the [validation data of small images](http://plac
 - Dataiset format:jpg files
 - Note:In the current release, data is customely loaded using a proprietary DataSet class (provided in dataset.cc). In the upcoming releases loading will be done using MindSpore MindData infrastructure. In order to fit the data to the model it will be preprocessed using [ImageMagick convert tool](https://imagemagick.org/), namely croping and converting to bmp format.
 - Note: Only 10 classes out of the 365 will be used in this demo
-- Note: 60% of the data will be used for training and 20% will be used for testing and the remaining 20% for validation
+- Note: 60% of the data will be used for training, 20% will be used for testing and the remaining 20% for validation

 - The original dataset directory structure is as follows:

@@ -68,7 +68,7 @@ where:
 - DATASET_PATH is the path to the [dataset](#dataset),
 - MINDSPORE_DOCKER is the image name of the docker that runs [MindSpore](#environment-requirements). If not provided MindSpore will be run locally
 - REALEASE.tar.gz is a pointer to the MindSpore ToD release tar ball. If not provided, the script will attempt to find MindSpore ToD compilation output
-- target is defaulted to arm64, i.e., on-device. If x86 is provided, the demo will be run locally. Note that infrastructure is not optimized for device
+- target is defaulted to arm64, i.e., on-device. If x86 is provided, the demo will be run locally. Note that infrastructure is not optimized for running on x86. Also, note that the user needs to call "make clean" when switching between targets.

 # Script Detailed Description

@@ -82,7 +82,7 @@ See how to run the script and paramaters definitions in the [Quick Start Section

 ## Preparing the model

-Within the model folder a `prepare_model.sh` script uses MindSpore infrastructure to export the model into a `.mindir` file. The user can specify a docker image on which MindSpore is installed. Otherwise, the pyhton script will be run locally.
+Within the model folder a `prepare_model.sh` script uses MindSpore infrastructure to export the model into a `.mindir` file. The user can specify a docker image on which MindSpore is installed. Otherwise, the python script will be run locally. As explained above, the head of the network is pretrained and a `.ckpt` file should be loaded to the head network. The first time the script is run, it attempts to download the `.ckpt` file using the `wget` command.
 The script then converts the `.mindir` to a `.ms` format using the MindSpore ToD converter.
 The script accepts a tar ball where the converter resides. Otherwise, the script will attempt to find the converter in the MindSpore ToD build output directory.



mindspore/lite/examples/transfer_learning/prepare_and_run.sh (+3, -2)

@@ -12,7 +12,6 @@ checkopts()
 PLACES_DATA_PATH=""
 while getopts 'D:d:r:t:' opt
 do
-  OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
   case "${opt}" in
     D)
       PLACES_DATA_PATH=$OPTARG
@@ -69,7 +68,7 @@ PACKAGE=package-${TARGET}

 rm -rf ${PACKAGE}
 mkdir -p ${PACKAGE}/model
-cp model/*.ms ${PACKAGE}/model
+cp model/*.ms ${PACKAGE}/model || exit 1

 # Copy the running script to the package
 cp scripts/*.sh ${PACKAGE}/
@@ -94,6 +93,8 @@ make TARGET=${TARGET}
 mv bin ${PACKAGE}/ || exit 1

 if [ "${TARGET}" == "arm64" ]; then
+  cp ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android/libc++_shared.so ${PACKAGE}/lib/ || exit 1
+
 echo "=======Pushing to device======="
 adb push ${PACKAGE} /data/local/tmp/



mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc (+1, -1)

@@ -37,7 +37,7 @@ int BiasGradCPUKernel::Init() {
     bias_param->out_shape_[i] = 1;  // 1 dimension for N,H,W,
   }
   bias_param->out_shape_[bias_param->ndim_ - 1] = dims[bias_param->ndim_ - 1];
-  for (int i = bias_param->ndim_; i < 4; i++) {
+  for (auto i = bias_param->ndim_; i < 4; i++) {
     bias_param->in_shape0_[i] = 0;
     bias_param->out_shape_[i] = 0;
  }

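Why `auto` helps in the hunk above: if `bias_param->ndim_` is an unsigned type (an assumption; its declaration is not part of this hunk), initializing an `int` index from it is a narrowing conversion that strict builds warn about, while `auto` keeps the index the same type as `ndim_`. A minimal standalone sketch of the pattern:

#include <cstddef>
#include <cstdio>

int main() {
  size_t ndim = 2;  // stand-in for bias_param->ndim_, assumed unsigned
  // `int i = ndim;` would draw a size_t -> int narrowing warning here;
  // `auto` deduces i as size_t, so the loop compares like types.
  for (auto i = ndim; i < 4; i++) {
    printf("zeroing trailing dim %zu\n", i);
  }
  return 0;
}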

mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc (+4, -0)

@@ -35,6 +35,10 @@ int PoolingGradCPUKernel::Init() {
   auto in_shape = in_tensors_.at(0)->shape();
   auto out_shape = in_tensors_.at(1)->shape();

+  if (pool_param->pool_mode_ == PoolMode_AvgPool) {
+    out_shape = in_tensors_.at(2)->shape();
+  }
+
   int input_h = in_shape.at(1);
   int input_w = in_shape.at(2);

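The hunk above means the gradient kernel reads the shape of the incoming gradient from input index 2 when the mode is average pooling. A minimal sketch of that shape selection, assuming the input order {x, forward output, dy} (inferred from this diff and the updated tests below, not from documentation):

#include <vector>

enum PoolMode { kMaxPool, kAvgPool };

// in_shapes[0] = x, in_shapes[1] = y (forward output), in_shapes[2] = dy.
std::vector<int> GradOutShape(PoolMode mode,
                              const std::vector<std::vector<int>> &in_shapes) {
  // Average pooling takes the gradient shape from dy (index 2); otherwise the
  // forward output y (index 1) is used, mirroring Init() above.
  return mode == kAvgPool ? in_shapes.at(2) : in_shapes.at(1);
}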


mindspore/lite/src/runtime/kernel/arm/fp32_grad/sigmoid_cross_entropy_with_logits.cc (+3, -3)

@@ -35,9 +35,9 @@ int SigmoidCrossEntropyWithLogitsCPUKernel::Execute(int task_id) {
   auto *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
   const size_t tensor_len = in_tensors_.at(0)->ElementsNum();

-  float zero = 0.0f;
-  float one = 1.0f;
-  float two = 2.0f;
+  const float zero = 0.0f;
+  const float one = 1.0f;
+  const float two = 2.0f;

   for (uint64_t i = 0; i < tensor_len; ++i) {
     if (logits[i] >= zero) {

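The `logits[i] >= zero` branch this hunk ends on is the usual device for evaluating sigmoid cross entropy without overflowing `exp`. For reference, a standalone sketch of the standard numerically stable formulation (the general technique, not necessarily the exact body of this kernel):

#include <algorithm>
#include <cmath>

// loss(x, z) = max(x, 0) - x * z + log(1 + exp(-|x|))
// Branching on the sign of x keeps the argument of exp non-positive, so it
// cannot overflow, and log1p preserves precision when exp(-|x|) is tiny.
float SigmoidCrossEntropyLoss(float logit, float label) {
  return std::max(logit, 0.0f) - logit * label +
         std::log1p(std::exp(-std::fabs(logit)));
}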

mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc (+4, -4)

@@ -47,7 +47,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *
       total_loss -= logf(losses[i * param->number_of_classes_ + label]);
     }
   }
-  output[0] = total_loss / param->batch_size_;
+  output[0] = total_loss / static_cast<float>(param->batch_size_);
   return RET_OK;
 }

@@ -67,9 +67,9 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *lab
     for (size_t j = 0; j < param->number_of_classes_; ++j) {
       size_t index = row_start + j;
       if (j == label) {
-        grads[index] = (losses[index] - 1) / param->batch_size_;
+        grads[index] = (losses[index] - 1) / static_cast<float>(param->batch_size_);
       } else {
-        grads[index] = losses[index] / param->batch_size_;
+        grads[index] = losses[index] / static_cast<float>(param->batch_size_);
       }
     }
   }
@@ -138,7 +138,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() {
   size_t data_size = in_tensors_.at(0)->ElementsNum();
   set_workspace_size((data_size + dims.at(0)) * sizeof(float));
   sm_params_.n_dim_ = 2;
-  sm_params_.element_size_ = data_size;
+  sm_params_.element_size_ = static_cast<int>(data_size);
   sm_params_.axis_ = 1;
   for (size_t i = 0; i < dims.size(); i++) sm_params_.input_shape_[i] = dims.at(i);

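`total_loss` and `losses[]` appear to be floats (`logf` feeds `total_loss`), so the divisions above were already floating-point; the added `static_cast<float>` mainly makes the conversion of `batch_size_` explicit and silences implicit-conversion warnings in strict builds. The cast becomes behavior-critical only when both operands are integers, as this standalone illustration shows:

#include <cstdio>

int main() {
  int total = 7;
  int batch_size = 2;
  float truncated = total / batch_size;  // int division runs first: 3.0f
  float exact = total / static_cast<float>(batch_size);  // 3.5f
  printf("%.1f vs %.1f\n", truncated, exact);
  return 0;
}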


mindspore/lite/src/train/train_session.h (+1, -1)

@@ -77,7 +77,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
     return lite::LiteSession::GetOutputByTensorName(tensor_name);
   }
   int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) override {
-    return lite::LiteSession::Resize(inputs, dims);
+    return lite::RET_ERROR;
   }

 protected:

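With this change, `Resize` on a training session no longer forwards to `lite::LiteSession::Resize`; it fails unconditionally, so callers must check the return code. A toy mirror of the new control flow with stand-in types (only the pattern is taken from the hunk above):

#include <cstdio>
#include <vector>

constexpr int RET_OK = 0;
constexpr int RET_ERROR = 1;  // stand-ins for lite::RET_OK / lite::RET_ERROR

struct ToyTrainSession {
  // Mirrors the hunk above: a training session now refuses to resize inputs.
  int Resize(const std::vector<std::vector<int>> & /*dims*/) { return RET_ERROR; }
};

int main() {
  ToyTrainSession session;
  if (session.Resize({{1, 32, 32, 1}}) != RET_OK) {
    printf("Resize is unsupported on a training session\n");
  }
  return 0;
}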

mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc (+4, -4)

@@ -139,7 +139,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) {
   lite::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x);
   x_tensor.set_data(input1_data);

-  std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor};
+  std::vector<lite::Tensor *> inputs = {&x_tensor, &x_tensor, &dy_tensor};

   auto output_data = new float[output_data_size];
   ASSERT_NE(output_data, nullptr);
@@ -209,7 +209,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingBatchGradFp32) {
   lite::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x);
   x_tensor.set_data(input1_data);

-  std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor};
+  std::vector<lite::Tensor *> inputs = {&x_tensor, &x_tensor, &dy_tensor};

   std::vector<int> dim_dx({3, 28, 28, 3});
   lite::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx);
@@ -282,7 +282,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride2Fp32) {
   lite::Tensor out_tensor(TypeId::kNumberTypeFloat32, dim_x);
   ASSERT_EQ(out_tensor.MallocData(), 0);
   float *out_data = static_cast<float *>(out_tensor.MutableData());
-  std::vector<lite::Tensor *> inputs = {&yt_tensor, &x_tensor};
+  std::vector<lite::Tensor *> inputs = {&x_tensor, &yt_tensor, &yt_tensor};
   std::vector<lite::Tensor *> outputs = {&out_tensor};

   lite::InnerContext context;
@@ -349,7 +349,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride3Fp32) {
   ASSERT_EQ(out_tensor.MallocData(), 0);
   auto out_data = static_cast<float *>(out_tensor.MutableData());

-  std::vector<lite::Tensor *> inputs = {&yt_tensor, &x_tensor};
+  std::vector<lite::Tensor *> inputs = {&x_tensor, &yt_tensor, &yt_tensor};
   std::vector<lite::Tensor *> outputs = {&out_tensor};

   lite::InnerContext context;

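These four test updates follow the three-input convention introduced in pooling_grad.cc above, with the upstream gradient moved to index 2. That the average-pool cases pass `x_tensor` (or `yt_tensor`) twice suggests the kernel never reads the forward output at index 1 for average pooling, so any correctly shaped tensor serves as a placeholder there; this is an inference from the hunks, not documented behavior. A sketch of the assumed wiring:

#include <vector>

struct Tensor {};  // stand-in; only the slot order matters here

// Assumed slot meanings after this commit: {x, forward output, dy}.
std::vector<Tensor *> MakeAvgPoolGradInputs(Tensor *x, Tensor *dy) {
  // Slot 1 is a placeholder for the avg-pool case, so x is reused there.
  return {x, x, dy};
}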
