| @@ -288,6 +288,7 @@ py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) { | |||||
| } | } | ||||
| bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) { | bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) { | ||||
| MS_EXCEPTION_IF_NULL(dtypes); | |||||
| auto signature = prim->signatures(); | auto signature = prim->signatures(); | ||||
| bool has_sig_dtype = false; | bool has_sig_dtype = false; | ||||
| (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes), | (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes), | ||||
| @@ -733,20 +734,29 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) { | |||||
| AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, | AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, | ||||
| abstract::AbstractBasePtrList *args_spec_list) { | abstract::AbstractBasePtrList *args_spec_list) { | ||||
| MS_EXCEPTION_IF_NULL(op_masks); | |||||
| MS_EXCEPTION_IF_NULL(args_spec_list); | |||||
| CNodePtr cnode = nullptr; | CNodePtr cnode = nullptr; | ||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| auto prim = op_exec_info->py_primitive; | auto prim = op_exec_info->py_primitive; | ||||
| const auto &signature = prim->signatures(); | |||||
| inputs.push_back(NewValueNode(prim)); | inputs.push_back(NewValueNode(prim)); | ||||
| size_t size = op_exec_info->op_inputs.size(); | size_t size = op_exec_info->op_inputs.size(); | ||||
| auto sig_size = signature.size(); | |||||
| // ignore signature for cast op | // ignore signature for cast op | ||||
| if (sig_size > 0 && sig_size != size) { | |||||
| MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires " | |||||
| << "inputs size " << sig_size; | |||||
| } | |||||
| bool is_cast_op = (op_exec_info->op_name == "Cast"); | bool is_cast_op = (op_exec_info->op_name == "Cast"); | ||||
| if (!is_cast_op) { | if (!is_cast_op) { | ||||
| const auto &signature = prim->signatures(); | |||||
| for (size_t i = 0; i < size; i++) { | for (size_t i = 0; i < size; i++) { | ||||
| auto obj = op_exec_info->op_inputs[i]; | auto obj = op_exec_info->op_inputs[i]; | ||||
| auto sig = SignatureEnumRW::kRWDefault; | auto sig = SignatureEnumRW::kRWDefault; | ||||
| if (signature.size() > 0) { | |||||
| if (sig_size > 0) { | |||||
| sig = signature[i].rw; | sig = signature[i].rw; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " " | MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " " | ||||
| @@ -455,7 +455,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||||
| >>> data.set_dtype(mindspore.int32) | >>> data.set_dtype(mindspore.int32) | ||||
| mindspore.int32 | mindspore.int32 | ||||
| )mydelimiter") | )mydelimiter") | ||||
| .def("set_cast_dtype", &Tensor::set_cast_dtype) | |||||
| .def("set_cast_dtype", &Tensor::set_cast_dtype, py::arg("dtype") = nullptr) | |||||
| .def("__str__", &Tensor::ToString) | .def("__str__", &Tensor::ToString) | ||||
| .def("__repr__", &Tensor::ToStringRepr) | .def("__repr__", &Tensor::ToStringRepr) | ||||
| .def(py::pickle( | .def(py::pickle( | ||||
| @@ -292,7 +292,6 @@ class _PynativeExecutor: | |||||
| def __init__(self): | def __init__(self): | ||||
| self._executor = PynativeExecutor_.get_instance() | self._executor = PynativeExecutor_.get_instance() | ||||
| #TODO(kpy):add a type arg | |||||
| def new_graph(self, obj, *args, **kwargs): | def new_graph(self, obj, *args, **kwargs): | ||||
| self._executor.new_graph(obj, *args, *(kwargs.values())) | self._executor.new_graph(obj, *args, *(kwargs.values())) | ||||
| @@ -269,7 +269,7 @@ class Tensor : public MetaTensor { | |||||
| std::string id() const { return id_; } | std::string id() const { return id_; } | ||||
| TypePtr cast_dtype() { return cast_dtype_; } | TypePtr cast_dtype() { return cast_dtype_; } | ||||
| void set_cast_dtype(TypePtr dtype) { cast_dtype_ = dtype; } | |||||
| void set_cast_dtype(TypePtr dtype = nullptr) { cast_dtype_ = dtype; } | |||||
| void SetNeedWait(bool need_wait) { | void SetNeedWait(bool need_wait) { | ||||
| if (event_ != nullptr) { | if (event_ != nullptr) { | ||||
| @@ -582,10 +582,13 @@ class Cell(Cell_): | |||||
| param (Parameter): The parameter to cast. | param (Parameter): The parameter to cast. | ||||
| """ | """ | ||||
| if hasattr(self, "_mindspore_flags"): | if hasattr(self, "_mindspore_flags"): | ||||
| if self._mindspore_flags.get('fp16'): | |||||
| param.set_cast_dtype(mstype.float16) | |||||
| if self._mindspore_flags.get('fp32'): | if self._mindspore_flags.get('fp32'): | ||||
| param.set_cast_dtype(mstype.float32) | param.set_cast_dtype(mstype.float32) | ||||
| elif self._mindspore_flags.get('fp16'): | |||||
| param.set_cast_dtype(mstype.float16) | |||||
| else: | |||||
| # retest dtype | |||||
| param.set_cast_dtype() | |||||
| return param | return param | ||||
| def insert_child_to_cell(self, child_name, child): | def insert_child_to_cell(self, child_name, child): | ||||
| @@ -464,7 +464,7 @@ raise_set = [ | |||||
| 'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}), | 'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}), | ||||
| 'desc_inputs': [0]}), | 'desc_inputs': [0]}), | ||||
| ('AssignAdd_Error', { | ('AssignAdd_Error', { | ||||
| 'block': (P.AssignAdd(), {'exception': IndexError}), | |||||
| 'block': (P.AssignAdd(), {'exception': ValueError}), | |||||
| 'desc_inputs': [[1]]}), | 'desc_inputs': [[1]]}), | ||||
| ] | ] | ||||