Merge pull request !1803 from Kang/master (tags/v0.5.0-beta)
@@ -242,6 +242,7 @@ const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
 const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
 const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
 const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
+const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
 // Comm ops
 const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");

@@ -251,6 +251,7 @@ extern const PrimitivePtr kPrimIs_;
 extern const PrimitivePtr kPrimIsNot;
 extern const PrimitivePtr kPrimInDict;
 extern const PrimitivePtr kPrimNotInDict;
+extern const PrimitivePtr kPrimMixedPrecisionCast;
 // Comm ops
 extern const PrimitivePtr kPrimMirror;

@@ -67,7 +67,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
   } else {
     return param;
   }
-  auto cast_helper = prim::GetPythonOps("_mp_cast_helper", "mindspore.ops.composite.base");
+  auto cast_helper = prim::kPrimMixedPrecisionCast;
   auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param});
   return cast;
 }

@@ -147,9 +147,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                         AnfNodeConfigPtr out_conf) {
   AbstractBasePtrList args_spec_list;
-  if (!prim_->isa<prim::DoSignaturePrimitive>()) {
-    MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString();
-  }
   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
   }

@@ -221,9 +218,6 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
   }
-  if (!prim_->isa<prim::UnpackGraphPrimitive>()) {
-    MS_LOG(EXCEPTION) << "Primitive should be UnpackGraphPrimitive, but got " << prim_->ToString();
-  }
   auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
   auto out_node = out_conf->node()->cast<CNodePtr>();

@@ -267,6 +261,63 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
   return engine->ForwardConfig(out_conf, fn_conf);
 }
+
+AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node_type, AnfNodePtr target_type,
+                                    FuncGraphPtr func_graph) {
+  AnfNodePtr target_node = source_node;
+  if (node_type->isa<AbstractTensor>()) {
+    auto x = node_type->cast<AbstractTensorPtr>();
+    if (x->element()->BuildType()->isa<Float>()) {
+      auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
+      MS_EXCEPTION_IF_NULL(cast);
+      target_node = func_graph->NewCNode({NewValueNode(cast), source_node, target_type});
+    }
+  } else if (node_type->isa<AbstractTuple>()) {
+    auto x = node_type->cast<AbstractTuplePtr>();
+    auto &items = x->elements();
+    std::size_t size = items.size();
+    std::vector<AnfNodePtr> nodes;
+    nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+    for (int i = 0; i < SizeToInt(size); i++) {
+      AnfNodePtr tuple_node =
+        func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(i)});
+      AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, items[i], target_type, func_graph);
+      nodes.emplace_back(node);
+    }
+    target_node = func_graph->NewCNode(nodes);
+  }
+  return target_node;
+}
+
+EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
+                                               AnfNodeConfigPtr out_conf) {
+  AbstractBasePtrList args_spec_list;
+  if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
+  }
+  auto out_node = out_conf->node()->cast<CNodePtr>();
+  const auto &out_node_inputs = out_node->inputs();
+  if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
+    MS_LOG(EXCEPTION) << "MixedPrecisionCast"
+                      << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
+                      << ", inputs size " << out_node_inputs.size();
+  }
+  AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
+  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
+                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
+  ScopePtr scope = kDefaultScope;
+  if (out_conf != nullptr) {
+    scope = out_conf->node()->scope();
+  }
+  ScopeGuard scope_guard(scope);
+  FuncGraphPtr func_graph = out_conf->node()->func_graph();
+  AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph);
+  AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
+  return engine->ForwardConfig(out_conf, fn_conf);
+}
+
 namespace {
 py::object BuildValue(const ValuePtr &value_ptr) {
   if (value_ptr == nullptr) {

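Aside for reviewers: the new MixedPrecisionCastHelper encodes the same rule the removed Python _mp_cast_helper expressed — cast floating-point tensors to the target type, recurse into tuples element-wise, and pass everything else through. A minimal Python sketch of that rule, with NumPy standing in for MindSpore tensors (illustrative only, not part of the patch):

    import numpy as np

    def mixed_precision_cast_sketch(dst_type, value):
        """Illustrative only: the cast rule implemented by MixedPrecisionCastHelper."""
        if isinstance(value, np.ndarray):
            # AbstractTensor branch: cast only floating-point tensors.
            return value.astype(dst_type) if np.issubdtype(value.dtype, np.floating) else value
        if isinstance(value, tuple):
            # AbstractTuple branch: the TupleGetItem/MakeTuple loop becomes element-wise recursion.
            return tuple(mixed_precision_cast_sketch(dst_type, v) for v in value)
        # Everything else (e.g. plain numbers) passes through unchanged.
        return value

    # mixed_precision_cast_sketch(np.float16, (np.ones(2), 3))  ->  (float16 array, 3)
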
@@ -102,6 +102,22 @@ class UnpackGraphEvaluator : public Evaluator {
   PrimitivePtr prim_;
 };
 
+class MixedPrecisionCastEvaluator : public Evaluator {
+ public:
+  explicit MixedPrecisionCastEvaluator(const PrimitivePtr primitive)
+      : Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {}
+  ~MixedPrecisionCastEvaluator() override = default;
+  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
+                    AnfNodeConfigPtr out_config = nullptr) override;
+
+  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
+    MS_LOG(EXCEPTION) << "Eval() should not be called; call Run() instead";
+  }
+
+ private:
+  PrimitivePtr prim_;
+};
+
 bool IsInWhiteList(PrimitivePtr primitive);
 StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);

@@ -308,6 +308,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
     evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
     return evaluator;
   }
+  if (prim->name() == prim::kPrimMixedPrecisionCast->name()) {
+    evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
+    return evaluator;
+  }
   if (prim->HasPyEvaluator()) {
     auto prim_py = dyn_cast<PrimitivePy>(prim);
     if (prim_py != nullptr) {

@@ -21,7 +21,6 @@ from ...common.parameter import Parameter, ParameterTuple
 from ...ops import composite as C
 from ...ops import functional as F
 from ...ops import operations as P
-from ...ops.composite.base import _mp_cast_helper
 from ...ops.operations.comm_ops import _VirtualDataset
 from ..cell import Cell
 from .grad_reducer import DistributedGradReducer

@@ -345,7 +344,7 @@ class WithEvalCell(Cell):
     def construct(self, data, label):
         outputs = self._network(data)
         if self.add_cast_fp32:
-            label = _mp_cast_helper(mstype.float32, label)
+            label = F.mixed_precision_cast(mstype.float32, label)
             outputs = F.cast(outputs, mstype.float32)
         loss = self._loss_fn(outputs, label)
         return loss, outputs, label

@@ -24,7 +24,6 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeF
 from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec
 from .. import functional as F
-from .. import operations as P
 from ...common.parameter import Parameter

@@ -297,32 +296,3 @@ env_get = MultitypeFuncGraph("env_get")
 def _tensor_env_get(env, parameter):
     """Used to get env."""
     return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like(parameter))
-
-_mp_cast_helper = MultitypeFuncGraph('mixed_precision_cast_helper')
-
-
-@_mp_cast_helper.register("TypeType", "Number")
-@core
-def _mixed_precision_cast_helper_1(type_, x):
-    """if x is float cast to type."""
-    # type_ is place holder
-    return x
-
-
-@_mp_cast_helper.register("TypeType", "Tensor")
-@core
-def _mixed_precision_cast_helper_2(type_, x):
-    """if x is float cast to type."""
-    if F.issubclass_(F.dtype(x), mstype.float_):
-        return P.Cast()(x, type_)
-    return x
-
-
-@_mp_cast_helper.register("TypeType", "Tuple")
-@core
-def _mixed_precision_cast_helper_3(type_, x):
-    """if x is a tuple"""
-    t = ()
-    for item in x:
-        t = t + (_mp_cast_helper(type_, item),)
-    return t

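Note: the three removed registrations map one-to-one onto the new C++ MixedPrecisionCastHelper above — the "Number" overload becomes the pass-through default, the "Tensor" overload becomes the AbstractTensor branch (cast only when the element type is Float), and the "Tuple" overload becomes the AbstractTuple branch (element-wise recursion via TupleGetItem/MakeTuple).
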
@@ -126,6 +126,7 @@ is_ = Primitive("is_")
 is_not = Primitive("is_not")
 in_dict = Primitive("in_dict")
 not_in_dict = Primitive("not_in_dict")
+mixed_precision_cast = Primitive("mixed_precision_cast")
 broadcast_gradient_args = Primitive('BroadcastGradientArgs')
 dot = Primitive('dot')
 array_reduce = Primitive('array_reduce')

@@ -21,7 +21,6 @@ from .._checkparam import Rel
 from ..common import dtype as mstype
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..ops import functional as F
-from ..ops.composite.base import _mp_cast_helper
 from ..parallel._utils import _get_parallel_mode
 from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
 from .parallel_utils import ParallelMode

@@ -98,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
         def construct(self, data, label):
             out = self._backbone(data)
-            label = _mp_cast_helper(mstype.float32, label)
+            label = F.mixed_precision_cast(mstype.float32, label)
             return self._loss_fn(F.cast(out, mstype.float32), label)
 
     validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)
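
Closing illustration: how the updated call sites fit together once the primitive is exported from functional.py. A hedged sketch modelled on the _add_loss_network hunk above (the class and parameter names here are illustrative, not part of the patch):

    import mindspore.nn as nn
    from mindspore.common import dtype as mstype
    from mindspore.ops import functional as F

    class WithFp32LabelLoss(nn.Cell):
        """Illustrative wrapper: cast float labels to float32 before the loss."""
        def __init__(self, backbone, loss_fn):
            super(WithFp32LabelLoss, self).__init__()
            self._backbone = backbone
            self._loss_fn = loss_fn

        def construct(self, data, label):
            out = self._backbone(data)
            # mixed_precision_cast recurses into tuples; non-float values pass through.
            label = F.mixed_precision_cast(mstype.float32, label)
            return self._loss_fn(F.cast(out, mstype.float32), label)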