|
|
@@ -24,7 +24,44 @@ |
|
|
namespace mindspore::lite { |
|
|
namespace mindspore::lite { |
|
|
int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, |
|
|
int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, |
|
|
std::vector<schema::TensorT *> *outputs) { |
|
|
std::vector<schema::TensorT *> *outputs) { |
|
|
|
|
|
auto p = GetCNodePrimitive(cnodePtr); |
|
|
auto attr = std::make_unique<schema::DepthwiseConv2DT>(); |
|
|
auto attr = std::make_unique<schema::DepthwiseConv2DT>(); |
|
|
|
|
|
|
|
|
|
|
|
auto format = GetValue<std::string>(p->GetAttr("data_format")); |
|
|
|
|
|
if (format == "NCHW") { |
|
|
|
|
|
attr->format = schema::Format_NCHW; |
|
|
|
|
|
} else if (format == "NHWC") { |
|
|
|
|
|
attr->format = schema::Format_NHWC; |
|
|
|
|
|
} else { |
|
|
|
|
|
attr->format = schema::Format_NUM_OF_FORMAT; |
|
|
|
|
|
} |
|
|
|
|
|
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pads")); |
|
|
|
|
|
attr->padUp = pad_list[0]; |
|
|
|
|
|
attr->padDown = pad_list[1]; |
|
|
|
|
|
attr->padLeft = pad_list[2]; |
|
|
|
|
|
attr->padRight = pad_list[3]; |
|
|
|
|
|
|
|
|
|
|
|
auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation")); |
|
|
|
|
|
attr->dilateH = dilation[0]; |
|
|
|
|
|
attr->dilateW = dilation[1]; |
|
|
|
|
|
|
|
|
|
|
|
auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size")); |
|
|
|
|
|
attr->kernelH = kernel_size[0]; |
|
|
|
|
|
attr->kernelW = kernel_size[1]; |
|
|
|
|
|
|
|
|
|
|
|
auto stride = GetValue<std::vector<int>>(p->GetAttr("stride")); |
|
|
|
|
|
attr->strideH = stride[2]; |
|
|
|
|
|
attr->strideW = stride[3]; |
|
|
|
|
|
|
|
|
|
|
|
auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode")); |
|
|
|
|
|
if (pad_mode == "valid") { |
|
|
|
|
|
attr->padMode = schema::PadMode_VALID; |
|
|
|
|
|
} else if (pad_mode == "same") { |
|
|
|
|
|
attr->padMode = schema::PadMode_SAME; |
|
|
|
|
|
} else { |
|
|
|
|
|
attr->padMode = schema::PadMode_NOTSET; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
node->nodeType = schema::NodeType_CNode; |
|
|
node->nodeType = schema::NodeType_CNode; |
|
|
node->primitive = std::make_unique<schema::PrimitiveT>(); |
|
|
node->primitive = std::make_unique<schema::PrimitiveT>(); |
|
|
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; |
|
|
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; |
|
|
|