Browse Source

modify onnx export

tags/v1.6.0
liuyang_655 4 years ago
parent
commit
e6c2cf7b73
3 changed files with 9 additions and 4 deletions
  1. +6
    -2
      mindspore/ccsrc/transform/express_ir/onnx_exporter.cc
  2. +2
    -1
      mindspore/python/mindspore/_checkparam.py
  3. +1
    -1
      mindspore/python/mindspore/nn/wrap/cell_wrapper.py

+ 6
- 2
mindspore/ccsrc/transform/express_ir/onnx_exporter.cc View File

@@ -806,8 +806,8 @@ void OnnxExporter::ExportPrimReduce(const FuncGraphPtr &, const CNodePtr &node,
if (int_ptr == nullptr) {
auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
if (tuple_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Got null pointer, the " << name
<< "Operator in your model is not support for exporting onnx.";
MS_LOG(EXCEPTION) << "Got null pointer, currently the " << name
<< " operator in your model is not support for exporting onnx.";
}
for (size_t i = 0; i < tuple_ptr->size(); ++i) {
attr_proto->add_ints(GetValue<int64_t>((*tuple_ptr)[i]));
@@ -977,6 +977,10 @@ void OnnxExporter::ExportPrimResizeNearestNeighbor(const FuncGraphPtr &, const C
std::vector<int64_t> resize_size;

auto tuple_ptr = dyn_cast<ValueTuple>(prim->GetAttr("size"));
if (tuple_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Got null pointer, currently the " << prim->name()
<< " operator in your model is not support for exporting onnx.";
}

for (size_t i = 0; i < x_shape->shape().size() - kTwoNum; i++) {
resize_size.push_back(x_shape->shape()[i]);


+ 2
- 1
mindspore/python/mindspore/_checkparam.py View File

@@ -961,7 +961,8 @@ def args_type_check(*type_args, **type_kwargs):
for name, value in argument_dict.items():
if name in bound_types:
if value is not None and not isinstance(value, bound_types[name]):
raise TypeError('The argument {} must be {}'.format(name, bound_types[name]))
raise TypeError("The argument {} must be {}, but got {}"
.format(name, bound_types[name], type(value)))
return func(*args, **kwargs)

return wrapper


+ 1
- 1
mindspore/python/mindspore/nn/wrap/cell_wrapper.py View File

@@ -702,7 +702,7 @@ class ParameterUpdate(Cell):
def __init__(self, param):
super(ParameterUpdate, self).__init__(auto_prefix=False)
if not isinstance(param, Parameter):
raise TypeError("For 'ParameterUpdate', 'param' must be 'Parameter', but got {}.".format(param))
raise TypeError("For 'ParameterUpdate', 'param' must be 'Parameter', but got {}.".format(type(param)))
self._param = param

def construct(self, x):


Loading…
Cancel
Save