Browse Source

fix bug

tags/v1.1.0
xuanyue 5 years ago
parent
commit
bd4568c88d
4 changed files with 23 additions and 3 deletions
  1. +9
    -0
      mindspore/lite/src/ops/instance_norm.cc
  2. +3
    -0
      mindspore/lite/src/ops/primitive_c.cc
  3. +10
    -2
      mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc
  4. +1
    -1
      mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc

+ 9
- 0
mindspore/lite/src/ops/instance_norm.cc View File

@@ -16,6 +16,11 @@

#include "src/ops/instance_norm.h"
#include <memory>

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -60,6 +65,10 @@ int InstanceNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu
}
float InstanceNorm::GetEpsilon() const { return this->primitive_->value_as_InstanceNorm()->epsilon(); }

PrimitiveC *InstanceNormCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<InstanceNorm>(primitive);
}
Registry InstanceNormRegistry(schema::PrimitiveType_InstanceNorm, InstanceNormCreator);
#endif
} // namespace lite
} // namespace mindspore

+ 3
- 0
mindspore/lite/src/ops/primitive_c.cc View File

@@ -143,6 +143,7 @@
#include "src/ops/audio_spectrogram.h"
#include "src/ops/mfcc.h"
#include "src/ops/identity.h"
#include "src/ops/instance_norm.h"

#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@@ -747,6 +748,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new AudioSpectrogram(primitive);
case schema::PrimitiveType_Mfcc:
return new Mfcc(primitive);
case schema::PrimitiveType_InstanceNorm:
return new InstanceNorm(primitive);

#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:


+ 10
- 2
mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc View File

@@ -71,6 +71,10 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
MS_ASSERT(nullptr != meta_graph_);
MS_ASSERT(nullptr != cNode);
auto primitiveCValue = PrimitiveC::Create(cNode->primitive.release());
if (primitiveCValue == nullptr) {
MS_LOG(ERROR) << "fail to convert primitive";
return nullptr;
}
cNode->primitive = nullptr;
// add quant parameter
if (cNode->quantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) {
@@ -156,8 +160,12 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
MS_ASSERT(nullptr != func_graph_);
for (const auto &cNode : meta_graph_->nodes) {
MS_ASSERT(nullptr != cNode);

std::vector<AnfNodePtr> op_inputs = {ConvertPrimitive(cNode)};
auto anf_primitive = ConvertPrimitive(cNode);
if (anf_primitive == nullptr) {
MS_LOG(ERROR) << "cannot obtain anf primitive";
return RET_NULL_PTR;
}
std::vector<AnfNodePtr> op_inputs = {anf_primitive};
for (unsigned int j : cNode->inputIndex) {
auto node = GetNode(j);
if (nullptr == node) {


+ 1
- 1
mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc View File

@@ -141,7 +141,7 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
auto matmul_cvalue = lite::PrimitiveC::Create(matmul_primitive.release());
// get matmul quantParams
std::vector<schema::QuantParamT> jointed_quant_params;
for (int i = 1; i < 9; i++) {
for (size_t i = 1; i < stack_cnode->inputs().size(); i++) {
auto fullconnect_node2 = stack_cnode->input(i)->cast<CNodePtr>();
auto fc_prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(fullconnect_node2->input(0));
auto fc_input_quantParams = fc_prim->GetInputQuantParams();


Loading…
Cancel
Save