Browse Source

add anf pass

tags/v0.7.0-beta
zhengjun10 5 years ago
parent
commit
6cdc86383b
9 changed files with 295 additions and 4 deletions
  1. +18
    -2
      mindspore/lite/src/gllo/common/utils.cc
  2. +3
    -0
      mindspore/lite/src/gllo/common/utils.h
  3. +64
    -0
      mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc
  4. +38
    -0
      mindspore/lite/src/gllo/fusion/conv_activation_fusion.h
  5. +2
    -2
      mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc
  6. +126
    -0
      mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc
  7. +40
    -0
      mindspore/lite/src/gllo/fusion/conv_scale_fusion.h
  8. +2
    -0
      mindspore/lite/test/CMakeLists.txt
  9. +2
    -0
      mindspore/lite/tools/converter/CMakeLists.txt

+ 18
- 2
mindspore/lite/src/gllo/common/utils.cc View File

@@ -16,9 +16,10 @@
#include <vector>
#include <memory>
#include "src/gllo/common/utils.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "src/ir/primitive_t_value.h"
#include "frontend/operator/ops.h"

using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>;
namespace mindspore {
namespace opt {

@@ -74,7 +75,11 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
}
}
}

if (a.m_ptr->isa<lite::PrimitiveTValue>()) {
auto a_value_node_ptr = a.m_ptr->cast<PrimitiveTValuePtr>();
auto b_value_node_ptr = b.m_ptr->cast<PrimitiveTValuePtr>();
return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type;
}
return a == b;
}

@@ -203,5 +208,16 @@ void CheckInputSize(const CNodePtr &node, const int size) {
}
}

schema::PrimitiveType GetCNodeType(const CNodePtr &node) {
auto value_primitive = node->input(0);
auto value_node = value_primitive->cast<ValueNodePtr>();
MS_ASSERT(value_node != nullptr);
auto value = value_node->value();
MS_ASSERT(value != nullptr);
auto primitive = value->cast<PrimitiveTValuePtr>();
MS_ASSERT(primitive != nullptr);
return primitive->GetPrimitiveT()->value.type;
}

} // namespace opt
} // namespace mindspore

+ 3
- 0
mindspore/lite/src/gllo/common/utils.h View File

@@ -21,6 +21,7 @@
#include "ir/func_graph.h"
#include "src/common/utils.h"
#include "src/gllo/common/pattern_engine.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace opt {
@@ -42,6 +43,8 @@ void CheckIfVarIsNull(const VarPtr &var);

void CheckInputSize(const CNodePtr &node, const int size);

schema::PrimitiveType GetCNodeType(const CNodePtr &node);

} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_


+ 64
- 0
mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc View File

@@ -0,0 +1,64 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*conv_activation_fusion.h
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/gllo/fusion/conv_activation_fusion.h"
#include <memory>
#include "schema/inner/model_generated.h"
#include "src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "src/gllo/common/utils.h"

namespace mindspore {
namespace opt {
const BaseRef ConvActivationFusion::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
// conv2d inputs may be 2 or 3 inputs,match move to process
auto prim = new schema::PrimitiveT();
prim->value.type = primitive_type;
auto prim_value = std::make_shared<lite::PrimitiveTValue>(prim);

return VectorRef({prim_value, X});
}

const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_LOG(DEBUG) << "conv activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
CheckIfFuncGraphIsNull(func_graph);

CheckIfAnfNodeIsNull(node);
auto act_node = node->cast<CNodePtr>();
CheckIfCNodeIsNull(act_node);
CheckInputSize(act_node, 2);

auto act_primitive = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(act_node->input(0));
if (act_primitive->GetPrimitiveT()->value.AsActivation()->type != activation_type) {
return node;
}
AnfNodePtr pre_node = act_node->input(1);
CheckIfAnfNodeIsNull(pre_node);
if (pre_node != nullptr && pre_node->isa<CNode>()) {
auto conv_node = pre_node->cast<CNodePtr>();
auto node_type = GetCNodeType(conv_node);
if (node_type == schema::PrimitiveType_Conv2D || node_type == schema::PrimitiveType_DepthwiseConv2D) {
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));
primitiveT_value->GetPrimitiveT()->value.AsConv2D()->activationType = activation_type;
return pre_node;
}
}
return node;
}
} // namespace opt
} // namespace mindspore

+ 38
- 0
mindspore/lite/src/gllo/fusion/conv_activation_fusion.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*conv_activation_fusion.h
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_

#include "src/gllo/common/optimizer.h"

namespace mindspore {
namespace opt {
class ConvActivationFusion : public PatternProcessPass {
public:
explicit ConvActivationFusion(bool multigraph = true,
schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
schema::ActivationType activation = schema::ActivationType_LEAKY_RELU) : primitive_type(
primitive), activation_type(activation), PatternProcessPass("conv_activation_fusion", multigraph) {}
~ConvActivationFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
schema::PrimitiveType primitive_type;
schema::ActivationType activation_type;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_

+ 2
- 2
mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc View File

@@ -15,8 +15,8 @@
*/
#include "src/gllo/fusion/conv_biasadd_fusion.h"
#include <memory>
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "schema/inner/model_generated.h"
#include "src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "src/gllo/common/utils.h"



+ 126
- 0
mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc View File

@@ -0,0 +1,126 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*conv_activation_fusion.h
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/gllo/fusion/conv_scale_fusion.h"
#include <memory>
#include "schema/inner/model_generated.h"
#include "src/ir/primitive_t_value.h"
#include "src/param_value_lite.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "src/gllo/common/utils.h"
#include "include/errorcode.h"

namespace mindspore {
namespace opt {
const BaseRef ConvScaleFusion::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
// conv2d inputs may be 2 or 3 inputs,match move to process
auto prim = new schema::PrimitiveT();
prim->value.type = schema::PrimitiveType_Scale;
auto prim_value = std::make_shared<lite::PrimitiveTValue>(prim);

return VectorRef({prim_value, X});
}

const AnfNodePtr ConvScaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_LOG(DEBUG) << "conv activation pass process";
CheckIfFuncGraphIsNull(func_graph);

CheckIfAnfNodeIsNull(node);
auto scale_node = node->cast<CNodePtr>();
CheckIfCNodeIsNull(scale_node);
CheckInputSize(scale_node, 2);

AnfNodePtr pre_node = scale_node->input(1);
CheckIfAnfNodeIsNull(pre_node);
if (pre_node != nullptr && pre_node->isa<CNode>()) {
auto conv_node = pre_node->cast<CNodePtr>();
auto node_type = GetCNodeType(conv_node);
if (node_type == schema::PrimitiveType_Conv2D || node_type == schema::PrimitiveType_DepthwiseConv2D) {
return DoFusion(conv_node, scale_node);
}
}
return node;
}
const AnfNodePtr ConvScaleFusion::DoFusion(const CNodePtr &conv_node, const CNodePtr &scale_node) const {
if (scale_node->inputs().size() == 3) {
GetTransParam(scale_node->input(2), nullptr);
} else if (scale_node->inputs().size() == 4) {
// todo add bias fusion zhengjun10
GetTransParam(scale_node->input(2), scale_node->input(3));
} else {
MS_LOG(ERROR) << "scale inputs size is error:" << scale_node->DebugString();
return nullptr;
}

AnfNodePtr conv_weight_node;
if (conv_node->inputs().size() == 3) {
conv_weight_node = conv_node->input(2);
} else {
MS_LOG(ERROR) << "scale inputs size is error:" << scale_node->DebugString();
return nullptr;
}
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
auto weight_value = std::dynamic_pointer_cast<ParamValueLite>(conv_weight_param);
auto old_conv_weight = reinterpret_cast<const float *>(weight_value->tensor_addr());

auto new_conv_weight = new(std::nothrow) float[weight_value->tensor_shape_size()];
CalNewWeightTensor(old_conv_weight, new_conv_weight, weight_value->tensor_shape_size());
weight_value->set_tensor_addr(new_conv_weight);
return conv_node;
}

const lite::STATUS ConvScaleFusion::GetTransParam(const AnfNodePtr &scale_weight_node,
const AnfNodePtr &scale_bias_node) const {
if (!scale_weight_node->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "scale weight node not paramter node";
}
if (scale_bias_node != nullptr && !scale_bias_node->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "scale bias node not paramter node";
}
auto scale_weight_param = scale_weight_node->cast<ParameterPtr>()->default_param();
auto weight_value = std::dynamic_pointer_cast<ParamValueLite>(scale_weight_param);
auto weight_data = reinterpret_cast<const float *>(weight_value->tensor_addr());

if (0 != memcpy_s(trans_scale, kernel_nums * sizeof(float), weight_data, kernel_nums * sizeof(float))) {
MS_LOG(ERROR) << "memcpy_s transScale failed";
return lite::RET_ERROR;
}
return lite::RET_OK;
}

const lite::STATUS ConvScaleFusion::CalNewWeightTensor(const float *oldWeightTensor, float *newWeightTensor,
const size_t tensor_shape_size) const {
MS_ASSERT(oldWeightTensor != nullptr);
if (0 != memset_s(newWeightTensor, tensor_shape_size * sizeof(float), 0, tensor_shape_size * sizeof(float))) {
MS_LOG(ERROR) << "memset newWeightData failed";
return lite::RET_ERROR;
}
if (kernel_nums == 0) {
MS_LOG(ERROR) << "kernel nums is 0";
return lite::RET_ERROR;
}
auto kernel_size = tensor_shape_size / kernel_nums;
for (size_t i = 0; i < kernel_nums; i++) {
for (size_t j = 0; j < kernel_size; j++) {
newWeightTensor[i * kernel_size + j] = oldWeightTensor[i * kernel_size + j] * trans_scale[i];
}
}
return lite::RET_OK;
}
} // namespace opt
} // namespace mindspore

+ 40
- 0
mindspore/lite/src/gllo/fusion/conv_scale_fusion.h View File

@@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*conv_activation_fusion.h
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_

#include "src/gllo/common/optimizer.h"

namespace mindspore {
namespace opt {
class ConvScaleFusion : public PatternProcessPass {
public:
explicit ConvScaleFusion(bool multigraph = true) : PatternProcessPass("conv_scale_fusion", multigraph) {}
~ConvScaleFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
const AnfNodePtr DoFusion(const CNodePtr &, const CNodePtr &) const;
const lite::STATUS GetTransParam(const AnfNodePtr &, const AnfNodePtr &) const;
const lite::STATUS CalNewWeightTensor(const float *, float *, const size_t) const;
private:
float *trans_scale = nullptr;
int kernel_nums = 0;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_


+ 2
- 0
mindspore/lite/test/CMakeLists.txt View File

@@ -193,6 +193,8 @@ if(BUILD_CONVERTER)
${LITE_DIR}/src/gllo/common/visit.cc
${LITE_DIR}/src/gllo/common/utils.cc
${LITE_DIR}/src/gllo/fusion/conv_biasadd_fusion.cc
${LITE_DIR}/src/gllo/fusion/conv_activation_fusion.cc
${LITE_DIR}/src/gllo/fusion/conv_scale_fusion.cc
)
endif()
### train


+ 2
- 0
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -78,6 +78,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/visit.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_biasadd_fusion.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_activation_fusion.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_scale_fusion.cc
)

add_subdirectory(parser/caffe)


Loading…
Cancel
Save