Merge pull request !1803 from Kang/mastertags/v0.5.0-beta
| @@ -242,6 +242,7 @@ const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_"); | |||
| const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); | |||
| const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | |||
| const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | |||
| const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | |||
| // Comm ops | |||
| const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| @@ -251,6 +251,7 @@ extern const PrimitivePtr kPrimIs_; | |||
| extern const PrimitivePtr kPrimIsNot; | |||
| extern const PrimitivePtr kPrimInDict; | |||
| extern const PrimitivePtr kPrimNotInDict; | |||
| extern const PrimitivePtr kPrimMixedPrecisionCast; | |||
| // Comm ops | |||
| extern const PrimitivePtr kPrimMirror; | |||
| @@ -67,7 +67,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo | |||
| } else { | |||
| return param; | |||
| } | |||
| auto cast_helper = prim::GetPythonOps("_mp_cast_helper", "mindspore.ops.composite.base"); | |||
| auto cast_helper = prim::kPrimMixedPrecisionCast; | |||
| auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param}); | |||
| return cast; | |||
| } | |||
| @@ -147,9 +147,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c | |||
| EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| AbstractBasePtrList args_spec_list; | |||
| if (!prim_->isa<prim::DoSignaturePrimitive>()) { | |||
| MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString(); | |||
| } | |||
| if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; | |||
| } | |||
| @@ -221,9 +218,6 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt | |||
| if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; | |||
| } | |||
| if (!prim_->isa<prim::UnpackGraphPrimitive>()) { | |||
| MS_LOG(EXCEPTION) << "Primitive should be UnpackGraphPrimitive, but got " << prim_->ToString(); | |||
| } | |||
| auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>(); | |||
| auto out_node = out_conf->node()->cast<CNodePtr>(); | |||
| @@ -267,6 +261,63 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt | |||
| return engine->ForwardConfig(out_conf, fn_conf); | |||
| } | |||
| AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node_type, AnfNodePtr target_type, | |||
| FuncGraphPtr func_graph) { | |||
| AnfNodePtr target_node = source_node; | |||
| if (node_type->isa<AbstractTensor>()) { | |||
| auto x = node_type->cast<AbstractTensorPtr>(); | |||
| if (x->element()->BuildType()->isa<Float>()) { | |||
| auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); | |||
| MS_EXCEPTION_IF_NULL(cast); | |||
| target_node = func_graph->NewCNode({NewValueNode(cast), source_node, target_type}); | |||
| } | |||
| } else if (node_type->isa<AbstractTuple>()) { | |||
| auto x = node_type->cast<AbstractTuplePtr>(); | |||
| auto &items = x->elements(); | |||
| std::size_t size = items.size(); | |||
| std::vector<AnfNodePtr> nodes; | |||
| nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| for (int i = 0; i < SizeToInt(size); i++) { | |||
| AnfNodePtr tuple_node = | |||
| func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(i)}); | |||
| AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, items[i], target_type, func_graph); | |||
| nodes.emplace_back(node); | |||
| } | |||
| target_node = func_graph->NewCNode(nodes); | |||
| } | |||
| return target_node; | |||
| } | |||
| EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| AbstractBasePtrList args_spec_list; | |||
| if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; | |||
| } | |||
| auto out_node = out_conf->node()->cast<CNodePtr>(); | |||
| const auto &out_node_inputs = out_node->inputs(); | |||
| if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { | |||
| MS_LOG(EXCEPTION) << "MixedPrecisionCast" | |||
| << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() | |||
| << ", inputs size " << out_node_inputs.size(); | |||
| } | |||
| AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); | |||
| ScopePtr scope = kDefaultScope; | |||
| if (out_conf != nullptr) { | |||
| scope = out_conf->node()->scope(); | |||
| } | |||
| ScopeGuard scope_guard(scope); | |||
| FuncGraphPtr func_graph = out_conf->node()->func_graph(); | |||
| AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph); | |||
| AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context()); | |||
| return engine->ForwardConfig(out_conf, fn_conf); | |||
| } | |||
| namespace { | |||
| py::object BuildValue(const ValuePtr &value_ptr) { | |||
| if (value_ptr == nullptr) { | |||
| @@ -102,6 +102,22 @@ class UnpackGraphEvaluator : public Evaluator { | |||
| PrimitivePtr prim_; | |||
| }; | |||
| class MixedPrecisionCastEvaluator : public Evaluator { | |||
| public: | |||
| explicit MixedPrecisionCastEvaluator(const PrimitivePtr primitive) | |||
| : Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {} | |||
| ~MixedPrecisionCastEvaluator() override = default; | |||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | |||
| AnfNodeConfigPtr out_config = nullptr) override; | |||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | |||
| } | |||
| private: | |||
| PrimitivePtr prim_; | |||
| }; | |||
| bool IsInWhiteList(PrimitivePtr primitive); | |||
| StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); | |||
| @@ -308,6 +308,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr | |||
| evaluator = std::make_shared<UnpackGraphEvaluator>(prim); | |||
| return evaluator; | |||
| } | |||
| if (prim->name() == prim::kPrimMixedPrecisionCast->name()) { | |||
| evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim); | |||
| return evaluator; | |||
| } | |||
| if (prim->HasPyEvaluator()) { | |||
| auto prim_py = dyn_cast<PrimitivePy>(prim); | |||
| if (prim_py != nullptr) { | |||
| @@ -21,7 +21,6 @@ from ...common.parameter import Parameter, ParameterTuple | |||
| from ...ops import composite as C | |||
| from ...ops import functional as F | |||
| from ...ops import operations as P | |||
| from ...ops.composite.base import _mp_cast_helper | |||
| from ...ops.operations.comm_ops import _VirtualDataset | |||
| from ..cell import Cell | |||
| from .grad_reducer import DistributedGradReducer | |||
| @@ -345,7 +344,7 @@ class WithEvalCell(Cell): | |||
| def construct(self, data, label): | |||
| outputs = self._network(data) | |||
| if self.add_cast_fp32: | |||
| label = _mp_cast_helper(mstype.float32, label) | |||
| label = F.mixed_precision_cast(mstype.float32, label) | |||
| outputs = F.cast(outputs, mstype.float32) | |||
| loss = self._loss_fn(outputs, label) | |||
| return loss, outputs, label | |||
| @@ -24,7 +24,6 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeF | |||
| from ...common import dtype as mstype | |||
| from ...common.api import ms_function, _pynative_exec | |||
| from .. import functional as F | |||
| from .. import operations as P | |||
| from ...common.parameter import Parameter | |||
| @@ -297,32 +296,3 @@ env_get = MultitypeFuncGraph("env_get") | |||
| def _tensor_env_get(env, parameter): | |||
| """Used to get env.""" | |||
| return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like(parameter)) | |||
| _mp_cast_helper = MultitypeFuncGraph('mixed_precision_cast_helper') | |||
| @_mp_cast_helper.register("TypeType", "Number") | |||
| @core | |||
| def _mixed_precision_cast_helper_1(type_, x): | |||
| """if x is float cast to type.""" | |||
| # type_ is place holder | |||
| return x | |||
| @_mp_cast_helper.register("TypeType", "Tensor") | |||
| @core | |||
| def _mixed_precision_cast_helper_2(type_, x): | |||
| """if x is float cast to type.""" | |||
| if F.issubclass_(F.dtype(x), mstype.float_): | |||
| return P.Cast()(x, type_) | |||
| return x | |||
| @_mp_cast_helper.register("TypeType", "Tuple") | |||
| @core | |||
| def _mixed_precision_cast_helper_3(type_, x): | |||
| """if x is a tuple""" | |||
| t = () | |||
| for item in x: | |||
| t = t + (_mp_cast_helper(type_, item),) | |||
| return t | |||
| @@ -126,6 +126,7 @@ is_ = Primitive("is_") | |||
| is_not = Primitive("is_not") | |||
| in_dict = Primitive("in_dict") | |||
| not_in_dict = Primitive("not_in_dict") | |||
| mixed_precision_cast = Primitive("mixed_precision_cast") | |||
| broadcast_gradient_args = Primitive('BroadcastGradientArgs') | |||
| dot = Primitive('dot') | |||
| array_reduce = Primitive('array_reduce') | |||
| @@ -21,7 +21,6 @@ from .._checkparam import Rel | |||
| from ..common import dtype as mstype | |||
| from ..nn.wrap.cell_wrapper import _VirtualDatasetCell | |||
| from ..ops import functional as F | |||
| from ..ops.composite.base import _mp_cast_helper | |||
| from ..parallel._utils import _get_parallel_mode | |||
| from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager | |||
| from .parallel_utils import ParallelMode | |||
| @@ -98,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type): | |||
| def construct(self, data, label): | |||
| out = self._backbone(data) | |||
| label = _mp_cast_helper(mstype.float32, label) | |||
| label = F.mixed_precision_cast(mstype.float32, label) | |||
| return self._loss_fn(F.cast(out, mstype.float32), label) | |||
| validator.check_value_type('loss_fn', loss_fn, nn.Cell, None) | |||