@@ -62,8 +62,6 @@ ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalar
                             {"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe},
                             {"__ge__", kPrimScalarGe}};
 
-const MetaFuncGraphPtr kTail = std::make_shared<Tail>("tail");
-
 // copy from python API: reduce.
 // Apply a function of two arguments cumulatively to the items of a sequence,
 // from left to right, so as to reduce the sequence to a single value.For example,
@@ -384,8 +382,8 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
                            .def(py::init<>());
                        }));
 
-FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) {
-  MS_EXCEPTION_IF_NULL(a_tuple);
+FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) {
+  MS_EXCEPTION_IF_NULL(sequeue);
 
   FuncGraphPtr ret = std::make_shared<FuncGraph>();
   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
@@ -393,31 +391,24 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tu
   AnfNodePtr ptrTup = ret->add_parameter();
 
   std::vector<AnfNodePtr> elems;
-  elems.push_back(NewValueNode(prim::kPrimMakeTuple));
-  int64_t tuple_size = SizeToLong(a_tuple->size());
-  for (int64_t i = 1; i < tuple_size; ++i) {
-    elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrTup, NewValueNode(i)}));
+  PrimitivePtr op = nullptr;
+  if (sequeue->isa<AbstractTuple>()) {
+    elems.push_back(NewValueNode(prim::kPrimMakeTuple));
+    op = prim::kPrimTupleGetItem;
+  } else {
+    elems.push_back(NewValueNode(prim::kPrimMakeList));
+    op = prim::kPrimListGetItem;
   }
-  ret->set_output(ret->NewCNode(elems));
-  return ret;
-}
-
-FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) {
-  MS_EXCEPTION_IF_NULL(a_list);
-
-  FuncGraphPtr ret = std::make_shared<FuncGraph>();
-  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-  ret->debug_info()->set_name("tail");
-  AnfNodePtr ptrList = ret->add_parameter();
-
-  std::vector<AnfNodePtr> elems;
-  elems.push_back(NewValueNode(prim::kPrimMakeList));
-  int64_t list_size = SizeToLong(a_list->size());
-  for (int64_t i = 1; i < list_size; ++i) {
-    elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), ptrList, NewValueNode(i)}));
+
+  for (size_t i = 1; i < sequeue->size(); ++i) {
+    if (do_grad_) {
+      MS_EXCEPTION_IF_NULL((*sequeue)[i]);
+      if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>()) {
+        elems.push_back(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
+      }
+    } else {
+      elems.push_back(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
+    }
   }
 
   ret->set_output(ret->NewCNode(elems));
@@ -430,14 +421,8 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list)
   }
 
   AbstractBasePtr a = args_spec_list[0];
-  abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(a);
-  if (a_tuple != nullptr) {
-    return GenerateTupleFuncGraph(a_tuple);
-  }
-
-  abstract::AbstractListPtr a_list = dyn_cast<AbstractList>(a);
-  if (a_list != nullptr) {
-    return GenerateListFuncGraph(a_list);
+  if (a->isa<AbstractTuple>() || a->isa<AbstractList>()) {
+    return GenerateSequeueFuncGraph(a->cast<abstract::AbstractSequeuePtr>());
   }
 
   MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString();
@@ -614,7 +599,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
 
   CNodePtr inputs_bprop = nullptr;
   if (get_all_) {
-    inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp});
+    TailPtr tail = std::make_shared<Tail>("tail", true);
+    inputs_bprop = func_graph->NewCNode({NewValueNode(tail), ptr_bapp});
   }
 
   // Gradients wrt inputs and parameters
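Because GradOperation now builds its own Tail("tail", true) when get_all_ is set, the tuple of input gradients coming out of the bprop is filtered by the grad-aware tail: only tensor-like positional inputs keep a gradient slot, while scalar, tuple, list and dict arguments are dropped. A minimal usage sketch under this change (the net and values are made up for illustration, not taken from the patch):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C


class ScaleNet(nn.Cell):
    def construct(self, x, scale):  # `scale` is a plain Python scalar
        return x * scale


grads = C.GradOperation(get_all=True)(ScaleNet())(Tensor(np.ones((2, 2), np.float32)), 3)
# Expected under this change: one gradient, for the tensor input `x`; the
# constant `scale` no longer occupies an entry in the returned tuple.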
@@ -99,15 +99,17 @@ extern ValuePtr kCompositeHyperMap;
 
 class Tail : public MetaFuncGraph {
  public:
-  explicit Tail(const std::string &name) : MetaFuncGraph(name) {}
+  explicit Tail(const std::string &name, bool do_grad = false) : MetaFuncGraph(name), do_grad_(do_grad) {}
   ~Tail() override = default;
   MS_DECLARE_PARENT(Tail, MetaFuncGraph)
 
   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
-  FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple);
-  FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list);
+  FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue);
 
   friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
+
+ private:
+  bool do_grad_;
 };
 using TailPtr = std::shared_ptr<Tail>;
@@ -446,10 +446,28 @@ bool TransformTopGraphPass(const ResourcePtr &res) {
 
 bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); }
 
+void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  std::vector<AnfNodePtr> new_paras;
+  for (const auto &param : func_graph->parameters()) {
+    auto param_node = param->cast<ParameterPtr>();
+    if (param_node->has_default()) {
+      new_paras.push_back(param_node);
+      continue;
+    }
+    AbstractBasePtr par_abs = param_node->abstract();
+    if (par_abs->isa<abstract::AbstractUndetermined>()) {
+      new_paras.push_back(param_node);
+    }
+  }
+  func_graph->set_parameters(new_paras);
+}
+
 bool ValidatePass(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(res->func_graph());
   FuncGraphPtr func_graph = res->func_graph();
   Validate(func_graph);
+  UpdateFuncGraphParameter(func_graph);
   return true;
 }
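UpdateFuncGraphParameter runs right after validation and prunes the top graph's parameter list: parameters with a default value (weights) or a tensor-like abstract (AbstractUndetermined) are kept, while parameters that stood for constant non-tensor arguments are removed, since those values have already been folded into the graph. A small Python model of the rule, using hypothetical parameter records for illustration only:

def update_func_graph_parameters(params):
    """Keep weights and tensor-like inputs; drop constant non-tensor parameters."""
    return [p for p in params
            if p["has_default"] or p["abstract"] == "undetermined"]


params = [
    {"name": "w", "has_default": True, "abstract": "tensor"},         # weight: kept
    {"name": "x", "has_default": False, "abstract": "undetermined"},  # tensor input: kept
    {"name": "axis", "has_default": False, "abstract": "tuple"},      # constant: dropped
]
assert [p["name"] for p in update_func_graph_parameters(params)] == ["w", "x"]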
@@ -69,6 +69,10 @@ namespace pipeline {
 using Tensor = mindspore::tensor::Tensor;
 using MetaTensor = mindspore::tensor::MetaTensor;
 using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>;
+using mindspore::abstract::AbstractDictionary;
+using mindspore::abstract::AbstractDictionaryPtr;
+using mindspore::abstract::AbstractList;
+using mindspore::abstract::AbstractListPtr;
 using mindspore::abstract::AbstractTensor;
 using mindspore::abstract::AbstractTensorPtr;
 using mindspore::abstract::AbstractTuple;
@@ -93,15 +97,10 @@ std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name)
   return oss.str();
 }
 
-void CheckArgIsTensor(const ValuePtr &arg, std::size_t idx) {
-  MS_EXCEPTION_IF_NULL(arg);
-  auto tensor_arg = arg->cast<TensorPtr>();
-  if (tensor_arg == nullptr) {
-    MS_EXCEPTION(TypeError) << "For 'graph mode', the " << idx << "th arg: " << arg->ToString() << " is not a tensor.";
-  }
-  if (tensor_arg->is_parameter()) {
-    MS_EXCEPTION(TypeError) << "The inputs could not be Parameter.";
-  }
+AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
+  MS_EXCEPTION_IF_NULL(value);
+  bool broaden = value->isa<MetaTensor>();
+  return abstract::FromValue(value, broaden);
 }
 }  // namespace
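ArgsToAbstract replaces the old tensor-only argument check: every argument is converted to an abstract value, and only tensor arguments (MetaTensor covers Tensor) are broadened to their shape/dtype, while scalars, tuples, lists and dicts stay as concrete constants that are folded into the graph. One practical consequence, sketched with plain-Python stand-ins (not MindSpore API), is that the compile key depends on non-tensor values but only on the shape/dtype of tensors:

import numpy as np


def arg_to_abstract(value):
    """Toy model: broaden tensor-like values, keep everything else as a constant."""
    if isinstance(value, np.ndarray):  # stand-in for MetaTensor/Tensor
        return ("tensor", value.shape, value.dtype.name)
    return ("const", value)


assert arg_to_abstract(np.ones((2, 3))) == arg_to_abstract(np.zeros((2, 3)))  # graph reused
assert arg_to_abstract(3) != arg_to_abstract(4)  # different constant -> different key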
@@ -117,8 +116,7 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::str
     if (!parse::ConvertData(arg.second, &converted)) {
       MS_LOG(EXCEPTION) << "GenerateKey convert arg failed";
     }
-    bool broaden = converted->isa<Tensor>() || converted->isa<MetaTensor>();
-    args_spec.push_back(abstract::FromValue(converted, broaden));
+    args_spec.push_back(ArgsToAbstract(converted));
   }
   if (g_args_cache.count(args_spec) == 0) {
     static int64_t key = 0;
@@ -484,11 +482,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
     if (!succ) {
       MS_LOG(EXCEPTION) << "Args convert error";
     }
-    if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
-      CheckArgIsTensor(converted, i);
-    }
-    bool broaden = true;
-    args_spec.push_back(abstract::FromValue(converted, broaden));
+    args_spec.push_back(ArgsToAbstract(converted));
   }
 
   resource->set_args_spec(args_spec);
@@ -814,9 +808,6 @@ py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) {
       if (!parse::ConvertData(args[i], &converted)) {
        MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed.";
       }
-      if (!converted->isa<tensor::Tensor>()) {
-        MS_EXCEPTION(TypeError) << "The " << i << "th arg: " << converted->ToString() << " is not tensor.";
-      }
     }
   }
   return *ret_val;
@@ -208,7 +208,11 @@ class _MindSporeFunction:
         if context.get_context("precompile_only"):
             return None
-        return self._executor(args_list, phase)
+        new_inputs = []
+        for i in args_list:
+            if isinstance(i, Tensor):
+                new_inputs.append(i)
+        return self._executor(tuple(new_inputs), phase)
 
 
 def ms_function(fn=None, obj=None, input_signature=None):
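On the Python side, _MindSporeFunction.__call__ now feeds only Tensor arguments to the compiled graph at run time; every other argument was treated as a constant during compilation and, after UpdateFuncGraphParameter, has no corresponding graph parameter left. A small illustration of the run-time filtering (the values are made up):

import numpy as np
from mindspore import Tensor

args_list = (Tensor(np.ones((2, 2), np.float32)), 3, (1, 2), {"k": 5})
new_inputs = tuple(a for a in args_list if isinstance(a, Tensor))
# Only the Tensor survives; 3, (1, 2) and {"k": 5} were folded into the graph as constants.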
@@ -18,7 +18,6 @@
 from ...composite import base
 from ... import functional as F
 
-
 zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf', True)
 """
 `zeros_like_leaf` is a metafuncgraph object which will generate a tensor filled with one according to its input type
@@ -31,11 +30,13 @@ def _zeros_like_scala(x):
     """Returns 0 which has the same dtype as x where x is a scalar."""
     return 0
 
+
 @zeros_like_leaf.register("Bool")
 def _zeros_like_bool(x):
     """Returns False if x is a bool."""
     return False
 
+
 newenv = base.EnvInstance_()
@@ -100,6 +101,25 @@ def _zeros_like_abstract_error(x):
     return x
 
 
+@zeros_like_leaf.register("Dictionary")
+def _zeros_like_dict(x):
+    """
+    Derivation of a AbstractError.
+
+    Args:
+        x (dict): the input
+
+    Returns:
+        dict, keys are same as input's keys, and value are same as zeros_like of input'value.
+    """
+    keys = x.keys()
+    values = x.values()
+    new_values = ()
+    for ele in values:
+        new_values += (zeros_like_leaf(ele),)
+    return F.make_dict(keys, new_values)
+
+
 # zeros_like is an object that will generate graph of zero_like operation for different type
 zeros_like = base.HyperMap(zeros_like_leaf)
 """`zeros_like` is an object that will generate graph of `zero_like` operation for different type."""
@@ -0,0 +1,60 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test ms_function pass non_tensor inputs"""
+import numpy as np
+from mindspore import Tensor, ms_function, Parameter
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True)
+
+
+@ms_function
+def compute(x, y, tuple_p, list_q, dict_w):
+    return x + y - tuple_p[0] + list_q[1] - dict_w["x"]
+
+
+def test_scalar_compute():
+    int_x = 1
+    int_y = 2
+    p = (3, 4)
+    q = [5, 6]
+    w = {"x": 7, "y": 8}
+    ret = compute(int_x, int_y, p, q, w)
+    assert ret == -1
+
+
+def test_tensor_compute():
+    tensor_x = Tensor(np.ones((2, 3, 4), np.float32))
+    tensor_y = Tensor(np.ones((2, 3, 4), np.float32) * 2)
+    p = (Tensor(np.ones((2, 3, 4), np.float32) * 3), Tensor(np.ones((2, 3, 4), np.float32) * 4))
+    q = [Tensor(np.ones((2, 3, 4), np.float32) * 5), Tensor(np.ones((2, 3, 4), np.float32) * 6)]
+    w = {"x": Tensor(np.ones((2, 3, 4), np.float32) * 7), "y": Tensor(np.ones((2, 3, 4), np.float32) * 8)}
+    compute(tensor_x, tensor_y, p, q, w)
+
+
+@ms_function
+def tensor_reduce(tensor_x, axis, tensor_y):
+    reduce_sum = P.ReduceSum()
+    ret = reduce_sum(tensor_x, axis) + tensor_y
+    return ret
+
+
+def test_tensor_reduce():
+    tensor_x = Tensor(np.ones((2, 3, 4, 5), np.float32))
+    axis = (0, 1)
+    tensor_y = Parameter(Tensor(np.ones((4, 5), np.float32) * 2))
+    tensor_reduce(tensor_x, axis, tensor_y)
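For reference, the scalar case works out as x + y - tuple_p[0] + list_q[1] - dict_w["x"] = 1 + 2 - 3 + 6 - 7 = -1, which is exactly the value asserted in test_scalar_compute.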
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-""" test outermost net pass scalar tuple list dict"""
-import pytest
+""" test outermost net pass non_tensor inputs"""
 import numpy as np
 
 import mindspore.nn as nn
@@ -28,7 +27,7 @@ def test_outermost_net_pass_scalar_tuple_list_dict():
 class TestNet(nn.Cell):
     def __init__(self):
         super(TestNet, self).__init__()
-        self.support_non_tensor_inputs = True
+        self.support_non_tensor_inputs = False
 
     def construct(self, tuple_a, z, list_m, w, s, dict_n):
         return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"]
@@ -58,12 +57,5 @@ def test_outermost_net_pass_scalar_tuple_list_dict():
     forward_net(arg_t1, z, arg_l1, x, 6, args_d1)
     grad_net = GradNet(forward_net)
 
-    with pytest.raises(TypeError) as err:
-        grad_net(arg_t0, z, arg_l0, w, 6, args_d0)
-    assert "For 'graph mode', the 0th arg" in str(err.value)
-
-    grad_net.support_non_tensor_inputs = True
-    with pytest.raises(ValueError) as err:
-        grad_net(arg_t0, z, arg_l0, w, 6, args_d0)
-    assert "Not support set 'support_non_tensor_inputs' to the 'True' for grad net, only support forward net." \
-           in str(err.value)
+    grad_net(arg_t0, z, arg_l0, w, 6, args_d0)
+    grad_net(arg_t1, z, arg_l1, x, 6, args_d1)
@@ -20,9 +20,9 @@ import mindspore.ops.operations as P
 from mindspore import Tensor, context
 from mindspore.common.api import ms_function
 from mindspore.ops import composite as C
-from mindspore.ops import functional as F
 from ...ut_filter import non_graph_engine
 
+
 # pylint: disable=unused-argument
 def setup_module(module):
     context.set_context(mode=context.PYNATIVE_MODE)
@@ -86,24 +86,6 @@ def test_cast_grad():
     assert np.all(gout[0].asnumpy() == expect)
 
 
-def test_scalar_cast_grad():
-    """ test_scalar_cast_grad """
-    input_x = 255.5
-    input_t = ms.int8
-
-    def fx_cast(x):
-        output = F.scalar_cast(x, input_t)
-        return output
-
-    @ms_function
-    def grad_fx_cast(input_x):
-        return grad(fx_cast)(input_x)
-
-    gfn = grad_fx_cast(input_x)
-    expect_dx = 1
-    assert gfn == expect_dx
-
-
 @non_graph_engine
 def test_reshape_grad():
     """ test_reshape_grad """
@@ -14,31 +14,28 @@
 # ============================================================================
 """ test_framstruct """
 import numpy as np
-import pytest
 
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.common import dtype as mstype
 from mindspore.common.parameter import Parameter, ParameterTuple
-from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.utils.check_gradient import (
     ms_function, check_jacobian, Tensor, NNGradChecker,
-    OperationGradChecker, check_gradient, ScalarGradChecker)
+    OperationGradChecker, check_gradient)
 
 context.set_context(mode=context.PYNATIVE_MODE)
 
 
 def setup_module(module):
     context.set_context(mode=context.PYNATIVE_MODE)
 
 
-grad = C.GradOperation()
 grad_all = C.GradOperation(get_all=True)
 grad_by_list = C.GradOperation(get_by_list=True)
-grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
 
 
 @ms_function
@@ -79,9 +76,7 @@ def dynamic_make_tuple(x, lower, upper):
 
 
 def test_dynamic_make_tuple():
-    # Dynamicly recursively creating static type is invalid in mindspore, as mindspore is a static language.
-    with pytest.raises(RuntimeError):
-        dynamic_make_tuple(2, 1, 5)
+    assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2)
 
 
 def test_make_tuple():
@@ -273,15 +268,6 @@ def rec(x):
         return rec(x - 1)
     return x
 
 
-@ms_function
-def grad_rec(input_x):
-    return grad(rec)(input_x)
-
-
-def test_grad_rec():
-    """ test_grad_rec """
-    res = grad_rec(3)
-    assert res == 1
-
-
 def test_me_rec():
     """ test_me_rec """
@@ -303,13 +289,6 @@ def test_while2():
     assert res == 6
 
 
-def test_grad_while2():
-    @ms_function
-    def df_t2_while(input_x, input_y):
-        return grad(t2_while)(input_x, input_y)
-    assert df_t2_while(2, 3) == 3
-
-
 def if_test(a, b):
     """ if_test """
     if a > b:
@@ -327,24 +306,6 @@ def test_grad_if():
     assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0)
 
 
-# While loop is not unrolled in forward and backward graphs.
-def test_dont_unroll_while():
-    def dont_unroll_while(x, y):
-        i = 2
-        out = y - x
-        while i < 10:
-            out = mul(x, y)
-            i = i + 1
-        return out
-    @ms_function()
-    def invoke_while(x, y):
-        return grad(dont_unroll_while)(x, y)
-    res = invoke_while(2, 3)
-    assert res == 3
-
-
 class ConvNet(nn.Cell):
     def __init__(self):
         super(ConvNet, self).__init__()
@@ -445,13 +406,6 @@ def test_factorial():
     assert res == 6
 
 
-def test_grad_factorial():
-    @ms_function
-    def df_factorial(x):
-        return grad(factorial)(x)
-    assert df_factorial(3) == 11
-
-
 @ms_function
 def factorial2(n):
     """ factorial """
@@ -523,17 +477,13 @@ def _for(x):
         ret = ret * i
     return ret
 
+
 @ms_function
 def grad_for(x):
     """ grad_for """
     return grad_all(_for)(x)
 
 
-def test_grad_for():
-    """ test_grad_for """
-    assert grad_for(5) == (60,)
-
-
 @ms_function
 def try_tail(x):
     """ try_tail """
@@ -675,15 +625,6 @@ def test_arithmetic_simplify_08():
     assert np.all(res.asnumpy() == expect)
 
 
-def test_ScalarGradChecker():
-    """ test_ScalarGradChecker """
-    def scalar_f(x, y):
-        return x * y
-    check_gradient(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, sampling_times=1)
-
-
 def test_GradCheckerPrimitive():
     """ test_GradCheckerPrimitive """
     matmul = P.MatMul()
@@ -737,15 +678,6 @@ def test_OperationGradChecker():
                    input_selector=[1], sampling_times=2)
 
 
-def test_ScalarJacobianChecker():
-    """ test_ScalarJacobianChecker """
-    def scalar_f(x, y):
-        return x * y
-    check_jacobian(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, input_selector=[0])
-
-
 def test_OperationJacobianChecker():
     """ test_OperationJacobianChecker """
@@ -795,13 +727,6 @@ def multi_outputs(x, y):
     return 2 * z, 2 * z
 
 
-def test_grad_multi_outputs():
-    @ms_function
-    def df_multi_outputs(x, y):
-        return grad_all_with_sens(multi_outputs)(x, y, (1, 1))
-    assert df_multi_outputs(2, 3) == (4, 4)
-
-
 @ms_function
 def while_sp(x, y, z):
     out = x
@@ -874,13 +799,6 @@ def grad_refactor_3(a):
     return 3 * a
 
 
-def test_grad_refactor_3():
-    @ms_function
-    def df_refactor_3(x):
-        return grad_all(grad_refactor_3)(x)
-    assert df_refactor_3(3) == (3,)
-
-
 def grad_refactor_4(a):
     """ if_test """
     if a > 3:
@@ -899,13 +817,6 @@ def grad_refactor_5(a):
     return a
 
 
-def test_grad_refactor_5():
-    @ms_function
-    def df_refactor_5(x):
-        return grad_all(grad_refactor_5)(x)
-    assert df_refactor_5(1) == (1,)
-
-
 def grad_refactor_6(a, b):
     """ if_test """
     if a > b:
@@ -925,13 +836,6 @@ def grad_refactor_while(x):
     return rval
 
 
-def test_grad_refactor_9():
-    @ms_function
-    def df_refactor_while(input_x):
-        return grad_all(grad_refactor_while)(input_x)
-    assert df_refactor_while(3) == (6,)
-
-
 def grad_refactor__while_1(x):
     """ _while """
     ret = x * x
@@ -1009,13 +913,6 @@ def grad_refactor_14(a, b):
     return inner1(b) + inner2(a) + inner3(a)
 
 
-def test_grad_refactor_14():
-    @ms_function
-    def df_refactor_14(x, y):
-        return grad_all(grad_refactor_14)(x, y)
-    assert df_refactor_14(2, 3) == (3, 9)
-
-
 # pylint: disable=using-constant-test
 class IfDeferInline(nn.Cell):
     def __init__(self, mul_size):
@@ -1044,6 +941,8 @@ def test_dict_const():
         def __init__(self):
            super(Net, self).__init__()
            self.res = {'1': 10}
+
        def construct(self):
            return self.res
+
    Net()()
@@ -109,25 +109,3 @@ def first_derivative_if(x):
 def second_derivative_if(x):
     """ second_derivative_if """
     return grad(first_derivative_if)(x)
-
-
-def test_high_order_grad_1():
-    """ test_high_order_grad_1 """
-    # 18
-    assert third_derivative(2) == 18
-    # 18 * y * y * y, 18 * x * x * x
-    assert third_derivative_dual(4, 5) == (2250, 1152)
-    # 18 * x
-    assert second_derivative_all(3) == 54
-
-
-def test_high_order_grad_2():
-    """ test_high_order_grad_2 """
-    # 2
-    assert second_derivative_if(12) == 2
-
-
-def test_high_order_grad_3():
-    """ test_high_order_grad_2 """
-    # 6 * x
-    assert second_derivative_if(4) == 24
@@ -325,7 +325,7 @@ def invoke_dataclass2(x, y):
 def test_access_attr_error():
     """ test_access """
     with pytest.raises(AttributeError):
-        invoke_dataclass2(1, 2)
+        invoke_dataclass2(2, 1)
 
 
 def myfunc(x):