|
|
|
@@ -133,14 +133,13 @@ std::tuple<ge::NodePtr, ge::ComputeGraphPtr> GenerateStubGeNode(const AnfNodePtr |
|
|
|
ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>(kStubDataStructureName, ge_node_name); |
|
|
|
MS_EXCEPTION_IF_NULL(op_desc); |
|
|
|
for (size_t i = 1; i < cnode->size(); ++i) { |
|
|
|
auto &input = cnode->input(i); |
|
|
|
std::vector<int64_t> ge_shape; |
|
|
|
auto ms_shape = AnfAlgo::GetOutputInferShape(input, 0); |
|
|
|
auto ms_shape = AnfAlgo::GetInputDeviceShape(cnode, i - 1); |
|
|
|
std::transform(ms_shape.begin(), ms_shape.end(), std::back_inserter(ge_shape), |
|
|
|
[](size_t in) { return static_cast<int64_t>(in); }); |
|
|
|
op_desc->AddInputDesc( |
|
|
|
ge::GeTensorDesc(ge::GeShape(ge_shape), ge::Format::FORMAT_NCHW, |
|
|
|
transform::TransformUtil::ConvertDataType(AnfAlgo::GetOutputInferDataType(input, 0)))); |
|
|
|
transform::TransformUtil::ConvertDataType(AnfAlgo::GetInputDeviceDataType(cnode, i - 1)))); |
|
|
|
} |
|
|
|
|
|
|
|
// set node data type |
|
|
|
|