Browse Source

!13134 [MS][LITE] tts_encoder models convert to weight quant failed.

From: @mengyuanli
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
cb90c963a8
7 changed files with 10 additions and 8 deletions
  1. +0
    -1
      mindspore/lite/nnacl/gather_parameter.h
  2. +1
    -2
      mindspore/lite/nnacl/infer/gather_infer.c
  3. +1
    -0
      mindspore/lite/nnacl/op_base.h
  4. +1
    -1
      mindspore/lite/src/mindrt_executor.cc
  5. +3
    -2
      mindspore/lite/src/ops/populate/where_populate.cc
  6. +2
    -0
      mindspore/lite/src/scheduler.cc
  7. +2
    -2
      mindspore/lite/tools/converter/quant_param_holder.h

+ 0
- 1
mindspore/lite/nnacl/gather_parameter.h View File

@@ -23,7 +23,6 @@ typedef struct GatherParameter {
// Primitive parameter
OpParameter op_parameter_;
int axis_;
int quant_type_;
} GatherParameter;

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_

+ 1
- 2
mindspore/lite/nnacl/infer/gather_infer.c View File

@@ -25,8 +25,7 @@ int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
const TensorC *indices = inputs[1];
TensorC *output = outputs[0];
output->data_type_ = input->data_type_;
GatherParameter *param = (GatherParameter *)parameter;
if (param->quant_type_ == QuantType_WeightQuant) {
if (parameter->quant_type_ == QuantType_WeightQuant) {
output->data_type_ = kNumberTypeFloat32;
}
output->format_ = input->format_;


+ 1
- 0
mindspore/lite/nnacl/op_base.h View File

@@ -80,6 +80,7 @@ typedef struct OpParameter {
bool infer_flag_;
int type_;
int thread_num_;
int quant_type_;
} OpParameter;

typedef struct QuantArg {


+ 1
- 1
mindspore/lite/src/mindrt_executor.cc View File

@@ -49,7 +49,7 @@ int MindrtExecutor::Prepare(const std::vector<kernel::LiteKernel *> &kernels) {

for (size_t j = 0; j < outTensorSize; j++) {
auto data =
std::make_shared<OpData<Tensor>>(opActors_[i]->GetAID(), kernels[i]->in_tensors()[j], static_cast<int>(j));
std::make_shared<OpData<Tensor>>(opActors_[i]->GetAID(), kernels[i]->out_tensors()[j], static_cast<int>(j));
outputData_.emplace_back(data);
}
}


+ 3
- 2
mindspore/lite/src/ops/populate/where_populate.cc View File

@@ -14,19 +14,20 @@
* limitations under the License.
*/
#include "src/ops/populate/populate_register.h"
#include "nnacl/where_parameter.h"

namespace mindspore {
namespace lite {
namespace {
OpParameter *PopulateWhereParameter(const void *prim) {
OpParameter *where_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
WhereParameter *where_parameter = reinterpret_cast<WhereParameter *>(malloc(sizeof(WhereParameter)));
if (where_parameter == nullptr) {
MS_LOG(ERROR) << "malloc Where parameter failed.";
return nullptr;
}
memset(where_parameter, 0, sizeof(OpParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
where_parameter->type_ = primitive->value_type();
where_parameter->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(where_parameter);
}
} // namespace


+ 2
- 0
mindspore/lite/src/scheduler.cc View File

@@ -135,6 +135,8 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node, bool *infer_shape_i
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << PrimitiveTypeName(GetPrimitiveType(primitive));
return RET_ERROR;
}
parameter->quant_type_ = node->quant_type_;

op_parameters_[node->output_indices_.at(0)] = parameter;
parameter->infer_flag_ = !(*infer_shape_interrupt);
auto ret = KernelInferShape(inputs, &outputs, parameter);


+ 2
- 2
mindspore/lite/tools/converter/quant_param_holder.h View File

@@ -81,7 +81,7 @@ class QuantParamHolder : public Value {
}

void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) {
if (index > this->input_quant_param_.size()) {
if (index >= this->input_quant_param_.size()) {
std::vector<schema::QuantParamT> place_quant(1);
this->input_quant_param_.insert(this->input_quant_param_.end(), index + 1 - input_quant_param_.size(),
place_quant);
@@ -94,7 +94,7 @@ class QuantParamHolder : public Value {
}

void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) {
if (index > this->output_quant_param_.size()) {
if (index >= this->output_quant_param_.size()) {
std::vector<schema::QuantParamT> place_quant(1);
this->output_quant_param_.insert(this->output_quant_param_.end(), index + 1 - output_quant_param_.size(),
place_quant);


Loading…
Cancel
Save