Browse Source

!1758 fix return parameter directly in net

Merge pull request !1758 from fary86/fix_return_parameter_directly_in_net
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
96ebda914c
1 changed files with 12 additions and 1 deletions
  1. +12
    -1
      mindspore/ccsrc/utils/convert_utils.cc

+ 12
- 1
mindspore/ccsrc/utils/convert_utils.cc View File

@@ -30,6 +30,7 @@
#include "pipeline/parse/parse_base.h"
#include "ir/value.h"
#include "ir/tensor.h"
#include "ir/param_value_py.h"
#include "utils/base_ref_extends.h"

namespace mindspore {
@@ -426,7 +427,17 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
<< " add Parameter count " << func_graph->hyper_param_count() << ".";
}
*ret_val = args[index];
if (index < args.size()) {
*ret_val = args[index];
} else {
auto param = dyn_cast<Parameter>(params[index]);
MS_EXCEPTION_IF_NULL(param);
if (!param->has_default()) {
MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
}
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
*ret_val = param_value->value().attr("data");
}
return true;
}
return false;


Loading…
Cancel
Save