Browse Source

fix bugs in instance_norm and pad operator

tags/v1.1.0
zengxianglong 5 years ago
parent
commit
86ae625dc7
7 changed files with 43 additions and 24 deletions
  1. +9
    -0
      mindspore/lite/src/ops/instance_norm.cc
  2. +1
    -2
      mindspore/lite/src/ops/pad.cc
  3. +2
    -2
      mindspore/lite/src/ops/populate/instance_norm_populate.cc
  4. +13
    -15
      mindspore/lite/src/ops/populate/pad_populate.cc
  5. +3
    -0
      mindspore/lite/src/ops/primitive_c.cc
  6. +13
    -4
      mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc
  7. +2
    -1
      mindspore/lite/tools/common/node_util.cc

+ 9
- 0
mindspore/lite/src/ops/instance_norm.cc View File

@@ -16,6 +16,11 @@

#include "src/ops/instance_norm.h"
#include <memory>

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -60,6 +65,10 @@ int InstanceNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu
}
float InstanceNorm::GetEpsilon() const { return this->primitive_->value_as_InstanceNorm()->epsilon(); }

PrimitiveC *InstanceNormCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<InstanceNorm>(primitive);
}
Registry InstanceNormRegistry(schema::PrimitiveType_InstanceNorm, InstanceNormCreator);
#endif
} // namespace lite
} // namespace mindspore

+ 1
- 2
mindspore/lite/src/ops/pad.cc View File

@@ -87,11 +87,10 @@ int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs)
}

std::vector<int> paddings;
if (GetPaddingMode() == static_cast<int>(schema::PaddingMode_CONSTANT)) {
if (inputs.size() == 1) {
paddings = GetPaddings();
} else {
// mirror pad
MS_ASSERT(inputs.size() == 2);
auto paddings_tensor = inputs.at(1);
int rank = static_cast<int>(inputs.front()->shape().size());
MS_ASSERT(paddings_tensor->ElementsNum() == 2 * rank);


+ 2
- 2
mindspore/lite/src/ops/populate/instance_norm_populate.cc View File

@@ -21,7 +21,7 @@

namespace mindspore {
namespace lite {
OpParameter *PopulateInstanceNorm(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *PopulateInstanceNormParameter(const mindspore::lite::PrimitiveC *primitive) {
const auto param =
reinterpret_cast<mindspore::lite::InstanceNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
InstanceNormParameter *instance_norm_param =
@@ -37,6 +37,6 @@ OpParameter *PopulateInstanceNorm(const mindspore::lite::PrimitiveC *primitive)
return reinterpret_cast<OpParameter *>(instance_norm_param);
}

Registry InstanceNormParameterRegistry(schema::PrimitiveType_L2Norm, PopulateInstanceNorm);
Registry InstanceNormParameterRegistry(schema::PrimitiveType_InstanceNorm, PopulateInstanceNormParameter);
} // namespace lite
} // namespace mindspore

+ 13
- 15
mindspore/lite/src/ops/populate/pad_populate.cc View File

@@ -32,23 +32,21 @@ OpParameter *PopulatePadParameter(const mindspore::lite::PrimitiveC *primitive)
pad_param->op_parameter_.type_ = primitive->Type();
auto pad_node = reinterpret_cast<mindspore::lite::Pad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
pad_param->pad_mode_ = pad_node->GetPaddingMode();
if (pad_param->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
pad_param->constant_value_ = pad_node->GetConstantValue();
auto size = pad_node->GetPaddings().size();
if (size > MAX_PAD_SIZE) {
MS_LOG(ERROR) << "Invalid padding size: " << size;
free(pad_param);
return nullptr;
}
pad_param->constant_value_ = pad_node->GetConstantValue();
auto size = pad_node->GetPaddings().size();
if (size > MAX_PAD_SIZE) {
MS_LOG(ERROR) << "Invalid padding size: " << size;
free(pad_param);
return nullptr;
}

for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) {
pad_param->paddings_[i] = 0;
}
for (size_t i = 0; i < size; i++) {
pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i];
}
pad_param->padding_length = MAX_PAD_SIZE;
for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) {
pad_param->paddings_[i] = 0;
}
for (size_t i = 0; i < size; i++) {
pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i];
}
pad_param->padding_length = MAX_PAD_SIZE;

return reinterpret_cast<OpParameter *>(pad_param);
}


+ 3
- 0
mindspore/lite/src/ops/primitive_c.cc View File

@@ -143,6 +143,7 @@
#include "src/ops/audio_spectrogram.h"
#include "src/ops/mfcc.h"
#include "src/ops/identity.h"
#include "src/ops/instance_norm.h"

#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@@ -790,6 +791,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new AudioSpectrogram(primitive);
case schema::PrimitiveType_Mfcc:
return new Mfcc(primitive);
case schema::PrimitiveType_InstanceNorm:
return new InstanceNorm(primitive);

#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:


+ 13
- 4
mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc View File

@@ -213,11 +213,20 @@ void PadCPUKernel::CalculateStrides() {
}

int PadCPUKernel::HandleMirrorPad() {
auto ret = CopyPaddingFromInput();
if (ret != RET_OK) {
return ret;
if (in_tensors_.size() == 1) {
auto input_shape = in_tensors_.at(0)->shape();
int rank = static_cast<int>(input_shape.size());
auto ret = ExtendShape(in_, DEFAULT_PAD_NDIMS, input_shape.data(), rank);
if (ret != RET_OK) {
return ret;
}
} else {
auto ret = CopyPaddingFromInput();
if (ret != RET_OK) {
return ret;
}
}
ret = CheckPaddings(pad_param_->paddings_, DEFAULT_PAD_NDIMS, in_, pad_param_->pad_mode_);
auto ret = CheckPaddings(pad_param_->paddings_, DEFAULT_PAD_NDIMS, in_, pad_param_->pad_mode_);
if (ret != RET_OK) {
return ret;
}


+ 2
- 1
mindspore/lite/tools/common/node_util.cc View File

@@ -46,7 +46,8 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
schema::PrimitiveType_BatchNorm,
schema::PrimitiveType_FusedBatchNorm,
schema::PrimitiveType_PReLU,
schema::PrimitiveType_BiasAdd};
schema::PrimitiveType_BiasAdd,
schema::PrimitiveType_InstanceNorm};

static const std::vector<schema::PrimitiveType> nhwcOpDualInputList = {
#ifdef SUPPORT_TRAIN


Loading…
Cancel
Save