|
|
|
@@ -50,6 +50,29 @@ bool IsSpecialType(const CNodePtr &cnode) {
   }
   return false;
 }
+
+STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr &cnode, size_t index) {
+  AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, index);
+  if (abstract == nullptr) {
+    MS_LOG(ERROR) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr";
+    return RET_ERROR;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
+    MS_LOG(DEBUG) << "Abstract of parameter should be abstract tensor";
+    return RET_ERROR;
+  }
+  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
+  if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape
+    MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
+    return RET_ERROR;
+  }
+  *tensor_info = utils::cast<tensor::TensorPtr>(abstract_tensor->GetValueTrack());
+  if (*tensor_info == nullptr) {
+    MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
 } // namespace

 abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) {
@@ -93,6 +116,50 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li
   return new_abstract;
 }

+STATUS InferShapePass::SetParameterAbstract(const ParameterPtr &parameter) {
+  MS_ASSERT(parameter != nullptr);
+  auto old_abstract = parameter->abstract();
+  if (old_abstract == nullptr) {
+    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << parameter->name();
+    return RET_ERROR;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(old_abstract)) {
+    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << parameter->name();
+    return RET_ERROR;
+  }
+  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(old_abstract);
+
+  auto type_ptr = abstract_tensor->element()->GetTypeTrack();
+  if (type_ptr == nullptr) {
+    MS_LOG(ERROR) << "type_ptr is nullptr";
+    return RET_ERROR;
+  }
+
+  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
+    MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << parameter->name();
+    return RET_ERROR;
+  }
+  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
+  std::vector<int32_t> shape;
+  (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
+                       [](const int64_t &value) { return static_cast<int32_t>(value); });
+
+  auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
+  auto new_tensor_info = std::make_shared<tensor::Tensor>(type_ptr->type_id(), shape_vector);
+  if (parameter->has_default()) {
+    auto old_tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param());
+    new_tensor_info = lite::CreateTensorInfo(old_tensor_info->data_c(), old_tensor_info->Size(),
+                                             old_tensor_info->shape(), old_tensor_info->data_type());
+    if (new_tensor_info == nullptr) {
+      MS_LOG(ERROR) << "create tensor info failed.";
+      return RET_ERROR;
+    }
+  }
+  new_abstract->set_value(new_tensor_info);
+  parameter->set_abstract(new_abstract);
+  return RET_OK;
+}
+
 void InferShapePass::FreeTensors(std::vector<lite::Tensor *> *tensors) {
   for (auto tensor : *tensors) {
     delete tensor;
@@ -104,6 +171,12 @@ void InferShapePass::FreeTensors(std::vector<lite::Tensor *> *tensors) {
 STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors) {
   MS_ASSERT(cnode != nullptr);
   MS_ASSERT(input_tensors != nullptr);
+  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
+  if (primitive == nullptr) {
+    MS_LOG(ERROR) << "primitive is nullptr: " << cnode->fullname_with_scope();
+    return RET_ERROR;
+  }
+  const int WEIGHT_INDEX = 2;
   auto inputs = cnode->inputs();
   for (size_t i = 1; i < inputs.size(); ++i) {
     auto input = inputs[i];
@@ -117,28 +190,14 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
       continue;
     }

-    AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, i);
-    if (abstract == nullptr) {
-      MS_LOG(ERROR) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr";
-      return RET_ERROR;
-    }
-    if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
-      MS_LOG(DEBUG) << "Abstract of parameter should be abstract tensor";
-      return RET_ERROR;
-    }
-    auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
-    if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape
-      MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
-      return RET_ERROR;
-    }
-    auto param_value_lite = utils::cast<tensor::TensorPtr>(abstract_tensor->GetValueTrack());
-    if (param_value_lite == nullptr) {
-      MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr";
+    tensor::TensorPtr tensor_info;
+    auto status = GetTensorInfoFromAbstract(&tensor_info, cnode, i);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "get tensor info failed.";
       return RET_ERROR;
     }
-
     std::unique_ptr<lite::Tensor> tensor = nullptr;
-    if (param_value_lite->data_type() != kObjectTypeTensorType) {
+    if (tensor_info->data_type() != kObjectTypeTensorType) {
       tensor = std::make_unique<lite::Tensor>();
     } else {
       tensor = std::make_unique<lite::TensorList>();
@@ -149,30 +208,36 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
     }

     std::vector<int> shape;
-    std::transform(param_value_lite->shape().begin(), param_value_lite->shape().end(), std::back_inserter(shape),
+    std::transform(tensor_info->shape().begin(), tensor_info->shape().end(), std::back_inserter(shape),
                    [](const int64_t &value) { return static_cast<int32_t>(value); });
-    if (param_value_lite->data_type() != kObjectTypeTensorType) {
+    if (tensor_info->data_type() != kObjectTypeTensorType) {
       tensor->set_shape(shape);
-      tensor->set_data_type(param_value_lite->data_type());
+      tensor->set_data_type(tensor_info->data_type());
+      if (primitive->GetAttr(opt::kWeightFormat) != nullptr && i == WEIGHT_INDEX) {
+        tensor->set_format(static_cast<schema::Format>(GetValue<int64_t>(primitive->GetAttr(opt::kWeightFormat))));
+      } else {
+        tensor->set_format(schema::Format::Format_NHWC);
+      }
     }

     if (utils::isa<ParameterPtr>(input)) {
       auto parameter = input->cast<ParameterPtr>();
       if (parameter->has_default()) {
-        auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param());
-        if (param_value_lite->data_type() != kObjectTypeTensorType) {
+        auto default_tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param());
+        if (tensor_info->data_type() != kObjectTypeTensorType) {
          auto ret = tensor->MallocData();
          if (ret != 0) {
            MS_LOG(ERROR) << "Malloc tensor data failed";
            return RET_ERROR;
          }
-          ret = memcpy_s(tensor->MutableData(), tensor->Size(), tensor_info->data_c(), tensor_info->Size());
+          ret =
+            memcpy_s(tensor->MutableData(), tensor->Size(), default_tensor_info->data_c(), default_tensor_info->Size());
          if (tensor->Size() != 0 && ret != EOK) {
            MS_LOG(ERROR) << "memcpy error: " << ret;
            return RET_ERROR;
          }
        } else {
-          int *data = reinterpret_cast<int *>(tensor_info->data_c());
+          int *data = reinterpret_cast<int *>(default_tensor_info->data_c());
          auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor.get());
          if (tensor_list->Decode(data) != RET_OK) {
            return RET_ERROR;
@@ -301,6 +366,10 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
   auto node_list = TopoSort(func_graph->get_return());
   for (auto &node : node_list) {
     if (utils::isa<ParameterPtr>(node)) {
+      int status = SetParameterAbstract(node->cast<ParameterPtr>());
+      if (status != RET_OK) {
+        return false;
+      }
       continue;
     }
     if (!utils::isa<CNodePtr>(node)) {