|
|
|
@@ -65,6 +65,9 @@ size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); } |
|
|
|
size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); } |
|
|
|
|
|
|
|
std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const { |
|
|
|
if (input_reshape_type_.empty()) { |
|
|
|
return {}; |
|
|
|
} |
|
|
|
if (input_index >= input_reshape_type_.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size " |
|
|
|
<< input_reshape_type_.size(); |
|
|
|
@@ -73,6 +76,9 @@ std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<Axis> KernelBuildInfo::GetOutputReshapeType(size_t output_index) const { |
|
|
|
if (output_reshape_type_.empty()) { |
|
|
|
return {}; |
|
|
|
} |
|
|
|
if (output_index >= output_reshape_type_.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size " |
|
|
|
<< output_reshape_type_.size(); |
|
|
|
|