|
|
|
@@ -15,7 +15,7 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "src/ops/pad.h" |
|
|
|
|
|
|
|
#include <string> |
|
|
|
#ifndef PRIMITIVE_WRITEABLE |
|
|
|
#include "src/ops/ops_register.h" |
|
|
|
#endif |
|
|
|
@@ -32,6 +32,52 @@ void Pad::SetPaddingMode(int padding_mode) { |
|
|
|
this->primitive_->value.AsPad()->paddingMode = (schema::PaddingMode)padding_mode; |
|
|
|
} |
|
|
|
// Writes the fill value into the Pad attribute of the owned primitive
// (presumably consumed when the padding mode is CONSTANT — confirm against schema).
void Pad::SetConstantValue(float constant_value) {
  auto pad_attr = this->primitive_->value.AsPad();
  pad_attr->constantValue = constant_value;
}
|
|
|
int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { |
|
|
|
if (this->primitive_ == nullptr) { |
|
|
|
this->primitive_ = new (std::nothrow) schema::PrimitiveT; |
|
|
|
if (this->primitive_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new primitiveT failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
this->primitive_->value.type = schema::PrimitiveType_Pad; |
|
|
|
} |
|
|
|
if (this->primitive_->value.type != schema::PrimitiveType_Pad) { |
|
|
|
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (this->primitive_->value.value == nullptr) { |
|
|
|
auto attr = new (std::nothrow) schema::PadT(); |
|
|
|
if (attr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new primitiveT value failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
string paddingmode = "REFLECT"; |
|
|
|
if (prim.GetAttr("mode") == nullptr) { |
|
|
|
MS_LOG(ERROR) << "get mode failed!"; |
|
|
|
delete this->primitive_; |
|
|
|
delete attr; |
|
|
|
this->primitive_ = nullptr; |
|
|
|
attr = nullptr; |
|
|
|
return RET_ERROR; |
|
|
|
} else { |
|
|
|
paddingmode = GetValue<string>(prim.GetAttr("mode")); |
|
|
|
} |
|
|
|
if (paddingmode == "REFLECT") { |
|
|
|
attr->paddingMode = schema::PaddingMode_REFLECT; |
|
|
|
} else if (paddingmode == "SYMMETRIC") { |
|
|
|
attr->paddingMode = schema::PaddingMode_SYMMETRIC; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "model type not supported!"; |
|
|
|
delete this->primitive_; |
|
|
|
delete attr; |
|
|
|
this->primitive_ = nullptr; |
|
|
|
attr = nullptr; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
this->primitive_->value.value = attr; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
#else |
|
|
|
|
|
|
|
@@ -94,14 +140,22 @@ int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) |
|
|
|
auto paddings_tensor = inputs.at(1); |
|
|
|
int rank = static_cast<int>(inputs.front()->shape().size()); |
|
|
|
MS_ASSERT(paddings_tensor->ElementsNum() == 2 * rank); |
|
|
|
int *paddings_data = reinterpret_cast<int *>(paddings_tensor->MutableData()); |
|
|
|
if (paddings_data == nullptr) { |
|
|
|
if (paddings_tensor->MutableData() == nullptr) { |
|
|
|
return RET_INFER_ERR; |
|
|
|
} |
|
|
|
paddings.clear(); |
|
|
|
for (auto i = 0; i < rank; ++i) { |
|
|
|
paddings.emplace_back(paddings_data[i * 2]); |
|
|
|
paddings.emplace_back(paddings_data[i * 2 + 1]); |
|
|
|
if (paddings_tensor->data_type() == mindspore::kNumberTypeInt64) { |
|
|
|
auto paddings_data = reinterpret_cast<int64_t *>(paddings_tensor->MutableData()); |
|
|
|
for (auto i = 0; i < rank; ++i) { |
|
|
|
paddings.emplace_back(paddings_data[i * 2]); |
|
|
|
paddings.emplace_back(paddings_data[i * 2 + 1]); |
|
|
|
} |
|
|
|
} else if (paddings_tensor->data_type() == mindspore::kNumberTypeInt32) { |
|
|
|
auto paddings_data = reinterpret_cast<int32_t *>(paddings_tensor->MutableData()); |
|
|
|
for (auto i = 0; i < rank; ++i) { |
|
|
|
paddings.emplace_back(paddings_data[i * 2]); |
|
|
|
paddings.emplace_back(paddings_data[i * 2 + 1]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|