|
|
|
@@ -98,22 +98,22 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) |
|
|
|
<< ", and the value is " << py::cast<py::str>(grads[i]) << "."; |
|
|
|
} |
|
|
|
|
|
|
|
py::tuple grad_shape = grads[i].attr("shape"); |
|
|
|
py::object arg_dtype = py_args[i].attr("dtype"); |
|
|
|
py::object grad_dtype = grads[i].attr("dtype"); |
|
|
|
py::tuple arg_shape = py_args[i].attr("shape"); |
|
|
|
py::object arg_dtype = py_args[i].attr("dtype"); |
|
|
|
py::tuple grad_shape = grads[i].attr("shape"); |
|
|
|
if (!grad_dtype.equal(arg_dtype)) { |
|
|
|
MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i |
|
|
|
<< "th arg should have the same dtype as the " << i << "th arg, but the " << i |
|
|
|
<< "th arg dtype is: " << py::cast<py::str>(arg_dtype) |
|
|
|
<< ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << "."; |
|
|
|
} |
|
|
|
if (!grad_shape.equal(arg_shape)) { |
|
|
|
MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i |
|
|
|
<< "th arg should have the same shape as the " << i << "th arg, but the " << i |
|
|
|
<< "th arg shape is: " << py::cast<py::str>(arg_shape) |
|
|
|
<< ", the gradient shape is: " << py::cast<py::str>(grad_shape) << "."; |
|
|
|
} |
|
|
|
if (!grad_dtype.is(arg_dtype)) { |
|
|
|
MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i |
|
|
|
<< "th arg should have the same dtype as the " << i << "th arg, but the " << i |
|
|
|
<< "th arg dtype is: " << py::cast<py::str>(arg_dtype) |
|
|
|
<< ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << "."; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -239,10 +239,7 @@ py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const { |
|
|
|
|
|
|
|
bool PrimitivePy::HasComputeFunction() const { |
|
|
|
auto func = GetComputeFunction(); |
|
|
|
if (py::isinstance<py::none>(func)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
return !py::isinstance<py::none>(func); |
|
|
|
} |
|
|
|
|
|
|
|
PrimitivePtr PrimitivePy::Clone() { |
|
|
|
@@ -272,7 +269,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { |
|
|
|
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") |
|
|
|
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") |
|
|
|
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") |
|
|
|
.def("set_is_const_value", &PrimitivePy::set_is_const_value, "Set primitive is const value.") |
|
|
|
.def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.") |
|
|
|
.def("set_const_input_indexes", &PrimitivePy::set_const_input_indexes, |
|
|
|
"Set primitive const input indexes.") |
|
|
|
.def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") |
|
|
|
.def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") |
|
|
|
.def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); |
|
|
|
|