|
|
|
@@ -78,12 +78,13 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) { |
|
|
|
} |
|
|
|
for (size_t i = 1; i < node_inputs.size(); ++i) { |
|
|
|
auto input = GetRealInput(node_inputs[i]); |
|
|
|
|
|
|
|
if (HasAbstractMonad(input)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (input->isa<Parameter>()) { |
|
|
|
auto input_parameter = input->cast<ParameterPtr>(); |
|
|
|
is_parameter.push_back(ParameterRequireGrad(input_parameter)); |
|
|
|
} else if ((input->isa<CNode>() && !HasAbstractMonad(input)) || IsValueNode<tensor::Tensor>(input) || |
|
|
|
IsValueNode<RefKey>(input)) { |
|
|
|
} else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) { |
|
|
|
is_parameter.push_back(false); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -174,6 +175,9 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) { |
|
|
|
|
|
|
|
// extract input element length |
|
|
|
for (auto &input : node_inputs) { |
|
|
|
if (HasAbstractMonad(input)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsValueNode<RefKey>(input)) { |
|
|
|
auto func_graph = node->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
@@ -182,8 +186,7 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) { |
|
|
|
MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; |
|
|
|
} |
|
|
|
inputs_type_len.push_back(GetInputsTypeLen(parameters[0])); |
|
|
|
} else if ((input->isa<CNode>() && !HasAbstractMonad(input)) || input->isa<Parameter>() || |
|
|
|
IsValueNode<tensor::Tensor>(input)) { |
|
|
|
} else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) { |
|
|
|
// extract input shape from parameter and apply node |
|
|
|
inputs_type_len.push_back(GetInputsTypeLen(input)); |
|
|
|
} |
|
|
|
|