|
|
|
@@ -218,16 +218,7 @@ void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index, |
|
|
|
std::vector<TypeId> *node_mix_precision_datatype) { |
|
|
|
AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index); |
|
|
|
MS_EXCEPTION_IF_NULL(cur_input); |
|
|
|
TypeId input_origin_type; |
|
|
|
if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) { |
|
|
|
// weight |
|
|
|
input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0); |
|
|
|
} else if (cur_input->isa<ValueNode>()) { |
|
|
|
input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0); |
|
|
|
} else { |
|
|
|
// feature map |
|
|
|
input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); |
|
|
|
} |
|
|
|
TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); |
|
|
|
AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index); |
|
|
|
node_mix_precision_datatype->push_back(input_origin_type); |
|
|
|
} |
|
|
|
@@ -297,6 +288,12 @@ bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_dat |
|
|
|
return !kernel_match_datatype_idx->empty(); |
|
|
|
} |
|
|
|
|
|
|
|
bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index, |
|
|
|
const std::vector<int> &node_mix_precision_datatype_index) { |
|
|
|
return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex && |
|
|
|
datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index]; |
|
|
|
} |
|
|
|
|
|
|
|
bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index, |
|
|
|
const std::vector<TypeId> &node_mix_precision_datatype, |
|
|
|
const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes, |
|
|
|
@@ -329,7 +326,7 @@ bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_preci |
|
|
|
if (i >= datatype_indexes.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "index " << i << "> kernel datatype indexes size " << datatype_indexes.size(); |
|
|
|
} |
|
|
|
if (datatype_indexes[i] == kUnSupportMixedDataTypeIndex) { |
|
|
|
if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) { |
|
|
|
iter = kernel_match_datatype_idx->erase(iter); |
|
|
|
} else { |
|
|
|
++iter; |
|
|
|
@@ -376,6 +373,7 @@ void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index, |
|
|
|
bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, |
|
|
|
kernel_support_datatype, kernel_match_datatype_idx); |
|
|
|
if (selected_ret) { |
|
|
|
*precision_reduce = false; |
|
|
|
return; |
|
|
|
} |
|
|
|
if (context_ptr->enable_reduce_precision()) { |
|
|
|
|