|
|
@@ -86,7 +86,10 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne |
|
|
MS_EXCEPTION_IF_NULL(address); |
|
|
MS_EXCEPTION_IF_NULL(address); |
|
|
auto shape = AnfAlgo::GetOutputInferShape(node, output_index); |
|
|
auto shape = AnfAlgo::GetOutputInferShape(node, output_index); |
|
|
TypeId type_id = kNumberTypeFloat32; |
|
|
TypeId type_id = kNumberTypeFloat32; |
|
|
type_id = AnfAlgo::GetOutputInferDataType(node, output_index); |
|
|
|
|
|
|
|
|
type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); |
|
|
|
|
|
if (type_id == kTypeUnknown) { |
|
|
|
|
|
type_id = AnfAlgo::GetOutputInferDataType(node, output_index); |
|
|
|
|
|
} |
|
|
std::vector<int> temp_shape; |
|
|
std::vector<int> temp_shape; |
|
|
if (graph.IsInternalOutput(node, output_index)) { |
|
|
if (graph.IsInternalOutput(node, output_index)) { |
|
|
temp_shape.emplace_back(1); |
|
|
temp_shape.emplace_back(1); |
|
|
|