@@ -14,6 +14,7 @@
 # ============================================================================
 """builtin_operations"""
 import numpy as np
+from mindspore.ops import functional as F
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
 
@@ -171,3 +172,12 @@ def tuple_to_array(x):
 def stop_gradient(x):
     """Implement `stop_gradient`."""
     return x
+
+def mixed_precision_cast(dst_type, x):
+    """Implement `mixed_precision_cast`."""
+    if isinstance(x, tuple):
+        res = list()
+        for item in x:
+            res.append(F.cast(item, dst_type))
+        return tuple(res)
+    return F.cast(x, dst_type)
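For orientation, here is a minimal self-contained sketch of the behaviour the new fallback is meant to provide (assumptions: PyNative mode is enabled explicitly, and F.cast accepts a (tensor, dtype) call as it does elsewhere in this change). The helper is re-declared locally purely for illustration and is not the operator registered above.

import numpy as np
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F

context.set_context(mode=context.PYNATIVE_MODE)

def _mixed_precision_cast_sketch(dst_type, x):
    # Mirrors the fallback above: a tuple is cast element-wise,
    # a single tensor is cast directly.
    if isinstance(x, tuple):
        return tuple(F.cast(item, dst_type) for item in x)
    return F.cast(x, dst_type)

single = Tensor(np.ones([2, 3], dtype=np.float32))
pair = (single, single)
assert _mixed_precision_cast_sketch(mstype.float16, single).dtype == mstype.float16
assert all(t.dtype == mstype.float16 for t in _mixed_precision_cast_sketch(mstype.float16, pair))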
@@ -61,7 +61,7 @@ struct OpExecInfo {
 using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
 OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);
 
-const std::set<std::string> ignore_infer_prim = {"make_ref"};
+const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
 }  // namespace pynative
 }  // namespace mindspore
@@ -57,7 +57,7 @@ using mindspore::tensor::TensorPy;
 const char SINGLE_OP_GRAPH[] = "single_op_graph";
 // primitive unable to infer value for constant input in PyNative mode
-const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};
+const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};
 
 namespace mindspore {
 namespace pynative {
@@ -815,6 +815,9 @@ PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
 void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
   auto cell_id = GetId(cell);
   if (cell_graph_map_.count(cell_id) != 0) {
+    if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
+      resource_ = cell_resource_map_[cell_id];
+    }
     MS_LOG(DEBUG) << "Newgraph already compiled";
     return;
   }
@@ -823,6 +826,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
   if (top_g_ == nullptr) {
     top_g_ = curr_g_ = g;
+    resource_ = std::make_shared<pipeline::Resource>();
+    cell_resource_map_[cell_id] = resource_;
     df_builder_ = std::make_shared<FuncGraph>();
     MS_LOG(DEBUG) << "First new graph" << top_g_.get();
     Pushp();
@@ -1124,6 +1129,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
     MS_LOG(DEBUG) << "Clear res";
     (void)graph_map_.erase(flag);
     (void)cell_graph_map_.erase(flag);
+    (void)cell_resource_map_.erase(flag);
     Clean();
     // Maybe exit in the pynative runing op, so need reset pynative flag.
     auto ms_context = MsContext::GetInstance();
@@ -1135,6 +1141,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
   MS_LOG(DEBUG) << "Clear";
   top_g_ = nullptr;
+  df_builder_ = nullptr;
   curr_g_ = nullptr;
   graph_info_map_.clear();
   op_id_map_.clear();
@@ -1146,7 +1153,6 @@ void PynativeExecutor::Clean() {
   Clear();
   grad_flag_ = false;
   op_forward_map_.clear();
-  df_builder_ = nullptr;
   ad::CleanRes();
   pipeline::ReclaimOptimizer();
 }
@@ -119,6 +119,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   bool grad_flag_;
   std::unordered_map<std::string, FuncGraphPtr> graph_map_;
   std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
+  std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
   std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
   std::unordered_map<std::string, ValuePtr> op_forward_map_;
   std::unordered_map<std::string, size_t> op_id_map_;
@@ -240,12 +240,13 @@ class Cell:
         else:
             _pynative_exec.set_grad_flag(False)
         cast_inputs = list()
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float16))
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float32))
+        if hasattr(self, "_mindspore_flags"):
+            if self._mindspore_flags.get('fp16'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float16))
+            if self._mindspore_flags.get('fp32'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float32))
         if cast_inputs:
             cast_inputs = tuple(cast_inputs)
         else:
@@ -496,10 +497,11 @@ class Cell:
         Args:
            param (Parameter): The parameter to cast.
        """
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
-            return cast(param, mstype.float16)
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
-            return cast(param, mstype.float32)
+        if hasattr(self, "_mindspore_flags"):
+            if self._mindspore_flags.get('fp16'):
+                return cast(param, mstype.float16)
+            if self._mindspore_flags.get('fp32'):
+                return cast(param, mstype.float32)
         return param
 
     def insert_child_to_cell(self, child_name, child):
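The two Cell hunks above are a pure restructuring: the hasattr(self, "_mindspore_flags") guard is hoisted so it is checked once instead of being repeated in every condition. A plain-Python sketch (no MindSpore dependency; flag names taken from the diff) showing that the old and new control flow pick the same precision:

def decide_old(obj):
    if hasattr(obj, "_mindspore_flags") and obj._mindspore_flags.get('fp16'):
        return 'fp16'
    if hasattr(obj, "_mindspore_flags") and obj._mindspore_flags.get('fp32'):
        return 'fp32'
    return None

def decide_new(obj):
    if hasattr(obj, "_mindspore_flags"):
        if obj._mindspore_flags.get('fp16'):
            return 'fp16'
        if obj._mindspore_flags.get('fp32'):
            return 'fp32'
    return None

class _Dummy:
    pass

# The decision is identical whether the flags attribute is absent, empty,
# or carries either precision flag.
for flags in (None, {}, {'fp16': True}, {'fp32': True}, {'fp16': False, 'fp32': True}):
    obj = _Dummy()
    if flags is not None:
        obj._mindspore_flags = flags
    assert decide_old(obj) == decide_new(obj)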
@@ -206,6 +206,7 @@ class TrainOneStepWithLossScaleCell(Cell):
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
+        self.network.set_grad()
         self.network.add_flags(defer_inline=True)
         self.weights = optimizer.parameters
         self.optimizer = optimizer
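The only change to the loss-scale wrapper is that set_grad() is now called on the wrapped network inside __init__, so callers no longer have to do it themselves before the first backward pass in PyNative mode. A rough construction sketch under that reading (TinyNet and its shapes are invented for illustration; the other nn classes are standard MindSpore wrappers, though constructor defaults may differ between versions):

import mindspore.nn as nn

class TinyNet(nn.Cell):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.dense = nn.Dense(3, 2)

    def construct(self, x):
        return self.dense(x)

net_with_loss = nn.WithLossCell(TinyNet(), nn.SoftmaxCrossEntropyWithLogits())
opt = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
# set_grad() now happens inside the wrapper's __init__.
train_net = nn.TrainOneStepWithLossScaleCell(net_with_loss, opt)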
@@ -20,6 +20,7 @@ import mindspore as ms
 from mindspore import Tensor
 from mindspore import context
 from mindspore import nn
+from mindspore.common import dtype as mstype
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
@@ -638,3 +639,9 @@ def test_large_for_loop_with_continue_break():
     t = Tensor(np.ones([2, 3], dtype=np.float32))
     net = Net()
     net(t)
+
+
+def test_mixed_precision_cast():
+    x = Tensor(np.ones([2, 3], dtype=np.float32))
+    z = F.mixed_precision_cast(mstype.float16, x)
+    assert z.dtype == mstype.float16
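A possible companion test, not part of this diff: it assumes that F.mixed_precision_cast forwards tuple inputs to the element-wise fallback added in builtin_operations above (an unverified assumption about the operator's tuple handling):

def test_mixed_precision_cast_tuple():
    # Hypothetical extension of the test above; assumes tuple inputs are
    # cast element-wise.
    x = Tensor(np.ones([2, 3], dtype=np.float32))
    y = Tensor(np.ones([4], dtype=np.float32))
    out = F.mixed_precision_cast(mstype.float16, (x, y))
    assert all(item.dtype == mstype.float16 for item in out)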