Browse Source

add adapter of Conv3D and Conv3DTranspose operators for graphengine.

pull/15115/head
wangshuide2020 4 years ago
parent
commit
a0a260dca1
5 changed files with 96 additions and 0 deletions
  1. +4
    -0
      mindspore/ccsrc/transform/graph_ir/op_adapter_map.h
  2. +8
    -0
      mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
  3. +57
    -0
      mindspore/ccsrc/transform/graph_ir/op_declare/nn_calculation_ops_declare.cc
  4. +19
    -0
      mindspore/ccsrc/transform/graph_ir/op_declare/nn_calculation_ops_declare.h
  5. +8
    -0
      mindspore/ccsrc/transform/graph_ir/util.cc

+ 4
- 0
mindspore/ccsrc/transform/graph_ir/op_adapter_map.h View File

@@ -60,6 +60,10 @@ constexpr const char kNameFlattenGrad[] = "FlattenGrad";
constexpr const char kNameConvolution[] = "Convolution";
constexpr const char kNameMaxPool3D[] = "MaxPool3D";
constexpr const char kNameMaxPool3DGrad[] = "MaxPool3DGrad";
constexpr const char kNameConv3DTransposeD[] = "Conv3DTranspose";
constexpr const char kNameConv3D[] = "Conv3D";
constexpr const char kNameConv3DBackpropInputD[] = "Conv3DBackpropInput";
constexpr const char kNameConv3DBackpropFilterD[] = "Conv3DBackpropFilter";
constexpr const char kNameBiasAdd[] = "BiasAdd";
constexpr const char kNameMaxPoolGrad[] = "MaxPoolGrad";
constexpr const char kNameRsqrtGrad[] = "RsqrtGrad";


+ 8
- 0
mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc View File

@@ -21,6 +21,7 @@
#include <algorithm>

#include "utils/utils.h"
#include "utils/check_convert_utils.h"
#include "transform/graph_ir/op_adapter_base.h"
#include "transform/graph_ir/io_format_map.h"

@@ -305,6 +306,13 @@ std::string GetOpIOFormat(const AnfNodePtr &anf) {
if (iter->second == "format") {
ValuePtr format = prim->GetAttr("format");
MS_EXCEPTION_IF_NULL(format);
std::string type_name = prim->name();
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, "format", &format);
if (!converted) {
MS_LOG(ERROR) << "Fail to convert from attr value to string"
<< " for Op: " << type_name;
return ret;
}
return GetValue<std::string>(format);
}
return iter->second;


+ 57
- 0
mindspore/ccsrc/transform/graph_ir/op_declare/nn_calculation_ops_declare.cc View File

@@ -64,6 +64,63 @@ ATTR_MAP(Conv2DBackpropFilterD) = {
OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv2DBackpropFilterD, prim::kPrimConv2DBackpropFilter->name(), ADPT_DESC(Conv2DBackpropFilterD))

// Conv3DTransposeD
INPUT_MAP(Conv3DTransposeD) = {
{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}, {4, INPUT_DESC(offset_w)}};
INPUT_ATTR_MAP(Conv3DTransposeD) = {
{5, ATTR_DESC(input_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(Conv3DTransposeD) = {
{"strides", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"dilations", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"groups", ATTR_DESC(groups, AnyTraits<int64_t>())},
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"output_padding", ATTR_DESC(output_padding, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
};
OUTPUT_MAP(Conv3DTransposeD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv3DTransposeD, kNameConv3DTransposeD, ADPT_DESC(Conv3DTransposeD))

// Conv3D
INPUT_MAP(Conv3D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}, {4, INPUT_DESC(offset_w)}};
ATTR_MAP(Conv3D) = {
{"strides", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"dilations", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"groups", ATTR_DESC(groups, AnyTraits<int64_t>())},
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"offset_x", ATTR_DESC(offset_x, AnyTraits<int64_t>())},
};
OUTPUT_MAP(Conv3D) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv3D, kNameConv3D, ADPT_DESC(Conv3D))

// Conv3DBackpropInputD
INPUT_MAP(Conv3DBackpropInputD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(filter)}};
INPUT_ATTR_MAP(Conv3DBackpropInputD) = {
{3, ATTR_DESC(input_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(Conv3DBackpropInputD) = {
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"strides", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>())},
{"dilations", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"groups", ATTR_DESC(groups, AnyTraits<int64_t>())},
};
OUTPUT_MAP(Conv3DBackpropInputD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv3DBackpropInputD, kNameConv3DBackpropInputD, ADPT_DESC(Conv3DBackpropInputD))

// Conv3DBackpropFilterD
INPUT_MAP(Conv3DBackpropFilterD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(x)}};
INPUT_ATTR_MAP(Conv3DBackpropFilterD) = {
{3, ATTR_DESC(filter_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(Conv3DBackpropFilterD) = {
{"strides", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"dilations", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"groups", ATTR_DESC(groups, AnyTraits<int64_t>())},
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
};
OUTPUT_MAP(Conv3DBackpropFilterD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv3DBackpropFilterD, kNameConv3DBackpropFilterD, ADPT_DESC(Conv3DBackpropFilterD))

// DepthwiseConv2D
INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}};
ATTR_MAP(DepthwiseConv2D) = {


+ 19
- 0
mindspore/ccsrc/transform/graph_ir/op_declare/nn_calculation_ops_declare.h View File

@@ -40,6 +40,25 @@ DECLARE_OP_USE_ENUM(Conv2DBackpropFilterD)
DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropFilterD)
DECLARE_OP_USE_OUTPUT(Conv2DBackpropFilterD)

DECLARE_OP_ADAPTER(Conv3DTransposeD)
DECLARE_OP_USE_ENUM(Conv3DTransposeD)
DECLARE_OP_USE_INPUT_ATTR(Conv3DTransposeD)
DECLARE_OP_USE_OUTPUT(Conv3DTransposeD)

DECLARE_OP_ADAPTER(Conv3D)
DECLARE_OP_USE_ENUM(Conv3D)
DECLARE_OP_USE_OUTPUT(Conv3D)

DECLARE_OP_ADAPTER(Conv3DBackpropInputD)
DECLARE_OP_USE_ENUM(Conv3DBackpropInputD)
DECLARE_OP_USE_INPUT_ATTR(Conv3DBackpropInputD)
DECLARE_OP_USE_OUTPUT(Conv3DBackpropInputD)

DECLARE_OP_ADAPTER(Conv3DBackpropFilterD)
DECLARE_OP_USE_ENUM(Conv3DBackpropFilterD)
DECLARE_OP_USE_INPUT_ATTR(Conv3DBackpropFilterD)
DECLARE_OP_USE_OUTPUT(Conv3DBackpropFilterD)

DECLARE_OP_ADAPTER(DepthwiseConv2D)
DECLARE_OP_USE_ENUM(DepthwiseConv2D)
DECLARE_OP_USE_OUTPUT(DepthwiseConv2D)


+ 8
- 0
mindspore/ccsrc/transform/graph_ir/util.cc View File

@@ -81,6 +81,14 @@ size_t TransformUtil::GetDataTypeSize(const MeDataType &type) {
GeFormat TransformUtil::ConvertFormat(const string &format) {
if (format == kOpFormat_NCHW) {
return GeFormat::FORMAT_NCHW;
} else if (format == kOpFormat_NDHWC) {
return GeFormat::FORMAT_NDHWC;
} else if (format == kOpFormat_NCDHW) {
return GeFormat::FORMAT_NCDHW;
} else if (format == kOpFormat_DHWNC) {
return GeFormat::FORMAT_DHWNC;
} else if (format == kOpFormat_DHWCN) {
return GeFormat::FORMAT_DHWCN;
} else if (format == kOpFormat_NC1HWC0) {
return GeFormat::FORMAT_NC1HWC0;
} else if (format == kOpFormat_NHWC) {


Loading…
Cancel
Save