|
|
@@ -523,38 +523,39 @@ std::vector<Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &no |
|
|
|
|
|
|
|
|
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) { |
|
|
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) { |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
TypePtr type_ptr = node->Type(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(type_ptr); |
|
|
|
|
|
if (type_ptr->isa<TensorType>() && output_idx == 0) { |
|
|
|
|
|
auto tensor_ptr = type_ptr->cast<TensorTypePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_ptr); |
|
|
|
|
|
TypePtr elem = tensor_ptr->element(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(elem); |
|
|
|
|
|
return elem->type_id(); |
|
|
|
|
|
} else if (type_ptr->isa<Tuple>()) { |
|
|
|
|
|
auto tuple_ptr = type_ptr->cast<TuplePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_ptr); |
|
|
|
|
|
if (output_idx >= tuple_ptr->size()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size(); |
|
|
|
|
|
} |
|
|
|
|
|
auto tuple_i = (*tuple_ptr)[output_idx]; |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_i); |
|
|
|
|
|
if (tuple_i->isa<TensorType>()) { |
|
|
|
|
|
auto tensor_ptr = tuple_i->cast<TensorTypePtr>(); |
|
|
|
|
|
|
|
|
auto get_single_type = [](const TypePtr &type_ptr) -> TypeId { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(type_ptr); |
|
|
|
|
|
if (type_ptr->isa<TensorType>()) { |
|
|
|
|
|
auto tensor_ptr = type_ptr->cast<TensorTypePtr>(); |
|
|
MS_EXCEPTION_IF_NULL(tensor_ptr); |
|
|
MS_EXCEPTION_IF_NULL(tensor_ptr); |
|
|
TypePtr elem = tensor_ptr->element(); |
|
|
TypePtr elem = tensor_ptr->element(); |
|
|
MS_EXCEPTION_IF_NULL(elem); |
|
|
MS_EXCEPTION_IF_NULL(elem); |
|
|
return elem->type_id(); |
|
|
return elem->type_id(); |
|
|
} else if (tuple_i->isa<Number>()) { |
|
|
|
|
|
return tuple_i->type_id(); |
|
|
|
|
|
} else { |
|
|
|
|
|
MS_LOG(WARNING) << "Not support type " << tuple_i->ToString(); |
|
|
|
|
|
return tuple_i->type_id(); |
|
|
|
|
|
} |
|
|
} |
|
|
} else if (type_ptr->isa<Number>()) { |
|
|
|
|
|
|
|
|
if (type_ptr->isa<Number>()) { |
|
|
|
|
|
return type_ptr->type_id(); |
|
|
|
|
|
} |
|
|
return type_ptr->type_id(); |
|
|
return type_ptr->type_id(); |
|
|
|
|
|
}; |
|
|
|
|
|
auto get_tuple_type = [get_single_type](const TypePtr &type_ptr, size_t output_idx) -> TypeId { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(type_ptr); |
|
|
|
|
|
if (!type_ptr->isa<Tuple>()) { |
|
|
|
|
|
return get_single_type(type_ptr); |
|
|
|
|
|
} |
|
|
|
|
|
auto tuple_ptr = type_ptr->cast<TuplePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_ptr); |
|
|
|
|
|
if (output_idx >= tuple_ptr->size()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size(); |
|
|
|
|
|
} |
|
|
|
|
|
return get_single_type((*tuple_ptr)[output_idx]); |
|
|
|
|
|
}; |
|
|
|
|
|
TypePtr type_ptr = node->Type(); |
|
|
|
|
|
if (type_ptr->isa<RefType>()) { |
|
|
|
|
|
auto ref_type_ptr = type_ptr->cast<RefTypePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ref_type_ptr); |
|
|
|
|
|
return get_tuple_type(ref_type_ptr->subtype(), output_idx); |
|
|
} |
|
|
} |
|
|
return type_ptr->type_id(); |
|
|
|
|
|
|
|
|
return get_tuple_type(type_ptr, output_idx); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) { |
|
|
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) { |
|
|
|