| @@ -62,8 +62,6 @@ ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalar | |||
| {"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe}, | |||
| {"__ge__", kPrimScalarGe}}; | |||
| const MetaFuncGraphPtr kTail = std::make_shared<Tail>("tail"); | |||
| // copy from python API: reduce. | |||
| // Apply a function of two arguments cumulatively to the items of a sequence, | |||
| // from left to right, so as to reduce the sequence to a single value.For example, | |||
| @@ -384,8 +382,8 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { | |||
| .def(py::init<>()); | |||
| })); | |||
| FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) { | |||
| MS_EXCEPTION_IF_NULL(a_tuple); | |||
| FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) { | |||
| MS_EXCEPTION_IF_NULL(sequeue); | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| @@ -393,31 +391,24 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tu | |||
| AnfNodePtr ptrTup = ret->add_parameter(); | |||
| std::vector<AnfNodePtr> elems; | |||
| elems.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| int64_t tuple_size = SizeToLong(a_tuple->size()); | |||
| for (int64_t i = 1; i < tuple_size; ++i) { | |||
| elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrTup, NewValueNode(i)})); | |||
| PrimitivePtr op = nullptr; | |||
| if (sequeue->isa<AbstractTuple>()) { | |||
| elems.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| op = prim::kPrimTupleGetItem; | |||
| } else { | |||
| elems.push_back(NewValueNode(prim::kPrimMakeList)); | |||
| op = prim::kPrimListGetItem; | |||
| } | |||
| ret->set_output(ret->NewCNode(elems)); | |||
| return ret; | |||
| } | |||
| FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) { | |||
| MS_EXCEPTION_IF_NULL(a_list); | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret->debug_info()->set_name("tail"); | |||
| AnfNodePtr ptrList = ret->add_parameter(); | |||
| std::vector<AnfNodePtr> elems; | |||
| elems.push_back(NewValueNode(prim::kPrimMakeList)); | |||
| int64_t list_size = SizeToLong(a_list->size()); | |||
| for (int64_t i = 1; i < list_size; ++i) { | |||
| elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), ptrList, NewValueNode(i)})); | |||
| for (size_t i = 1; i < sequeue->size(); ++i) { | |||
| if (do_grad_) { | |||
| MS_EXCEPTION_IF_NULL((*sequeue)[i]); | |||
| if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>()) { | |||
| elems.push_back(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); | |||
| } | |||
| } else { | |||
| elems.push_back(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); | |||
| } | |||
| } | |||
| ret->set_output(ret->NewCNode(elems)); | |||
| @@ -430,14 +421,8 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) | |||
| } | |||
| AbstractBasePtr a = args_spec_list[0]; | |||
| abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(a); | |||
| if (a_tuple != nullptr) { | |||
| return GenerateTupleFuncGraph(a_tuple); | |||
| } | |||
| abstract::AbstractListPtr a_list = dyn_cast<AbstractList>(a); | |||
| if (a_list != nullptr) { | |||
| return GenerateListFuncGraph(a_list); | |||
| if (a->isa<AbstractTuple>() || a->isa<AbstractList>()) { | |||
| return GenerateSequeueFuncGraph(a->cast<abstract::AbstractSequeuePtr>()); | |||
| } | |||
| MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString(); | |||
| @@ -614,7 +599,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An | |||
| CNodePtr inputs_bprop = nullptr; | |||
| if (get_all_) { | |||
| inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp}); | |||
| TailPtr tail = std::make_shared<Tail>("tail", true); | |||
| inputs_bprop = func_graph->NewCNode({NewValueNode(tail), ptr_bapp}); | |||
| } | |||
| // Gradients wrt inputs and parameters | |||
| @@ -99,15 +99,17 @@ extern ValuePtr kCompositeHyperMap; | |||
| class Tail : public MetaFuncGraph { | |||
| public: | |||
| explicit Tail(const std::string &name) : MetaFuncGraph(name) {} | |||
| explicit Tail(const std::string &name, bool do_grad = false) : MetaFuncGraph(name), do_grad_(do_grad) {} | |||
| ~Tail() override = default; | |||
| MS_DECLARE_PARENT(Tail, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||
| FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple); | |||
| FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list); | |||
| FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue); | |||
| friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } | |||
| private: | |||
| bool do_grad_; | |||
| }; | |||
| using TailPtr = std::shared_ptr<Tail>; | |||
| @@ -446,10 +446,28 @@ bool TransformTopGraphPass(const ResourcePtr &res) { | |||
| bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); } | |||
| void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> new_paras; | |||
| for (const auto ¶m : func_graph->parameters()) { | |||
| auto param_node = param->cast<ParameterPtr>(); | |||
| if (param_node->has_default()) { | |||
| new_paras.push_back(param_node); | |||
| continue; | |||
| } | |||
| AbstractBasePtr par_abs = param_node->abstract(); | |||
| if (par_abs->isa<abstract::AbstractUndetermined>()) { | |||
| new_paras.push_back(param_node); | |||
| } | |||
| } | |||
| func_graph->set_parameters(new_paras); | |||
| } | |||
| bool ValidatePass(const ResourcePtr &res) { | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| Validate(func_graph); | |||
| UpdateFuncGraphParameter(func_graph); | |||
| return true; | |||
| } | |||
| @@ -69,6 +69,10 @@ namespace pipeline { | |||
| using Tensor = mindspore::tensor::Tensor; | |||
| using MetaTensor = mindspore::tensor::MetaTensor; | |||
| using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>; | |||
| using mindspore::abstract::AbstractDictionary; | |||
| using mindspore::abstract::AbstractDictionaryPtr; | |||
| using mindspore::abstract::AbstractList; | |||
| using mindspore::abstract::AbstractListPtr; | |||
| using mindspore::abstract::AbstractTensor; | |||
| using mindspore::abstract::AbstractTensorPtr; | |||
| using mindspore::abstract::AbstractTuple; | |||
| @@ -93,15 +97,10 @@ std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) | |||
| return oss.str(); | |||
| } | |||
| void CheckArgIsTensor(const ValuePtr &arg, std::size_t idx) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| auto tensor_arg = arg->cast<TensorPtr>(); | |||
| if (tensor_arg == nullptr) { | |||
| MS_EXCEPTION(TypeError) << "For 'graph mode', the " << idx << "th arg: " << arg->ToString() << " is not a tensor."; | |||
| } | |||
| if (tensor_arg->is_parameter()) { | |||
| MS_EXCEPTION(TypeError) << "The inputs could not be Parameter."; | |||
| } | |||
| AbstractBasePtr ArgsToAbstract(const ValuePtr &value) { | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| bool broaden = value->isa<MetaTensor>(); | |||
| return abstract::FromValue(value, broaden); | |||
| } | |||
| } // namespace | |||
| @@ -117,8 +116,7 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::str | |||
| if (!parse::ConvertData(arg.second, &converted)) { | |||
| MS_LOG(EXCEPTION) << "GenerateKey convert arg failed"; | |||
| } | |||
| bool broaden = converted->isa<Tensor>() || converted->isa<MetaTensor>(); | |||
| args_spec.push_back(abstract::FromValue(converted, broaden)); | |||
| args_spec.push_back(ArgsToAbstract(converted)); | |||
| } | |||
| if (g_args_cache.count(args_spec) == 0) { | |||
| static int64_t key = 0; | |||
| @@ -484,11 +482,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons | |||
| if (!succ) { | |||
| MS_LOG(EXCEPTION) << "Args convert error"; | |||
| } | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) { | |||
| CheckArgIsTensor(converted, i); | |||
| } | |||
| bool broaden = true; | |||
| args_spec.push_back(abstract::FromValue(converted, broaden)); | |||
| args_spec.push_back(ArgsToAbstract(converted)); | |||
| } | |||
| resource->set_args_spec(args_spec); | |||
| @@ -814,9 +808,6 @@ py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { | |||
| if (!parse::ConvertData(args[i], &converted)) { | |||
| MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; | |||
| } | |||
| if (!converted->isa<tensor::Tensor>()) { | |||
| MS_EXCEPTION(TypeError) << "The " << i << "th arg: " << converted->ToString() << " is not tensor."; | |||
| } | |||
| } | |||
| } | |||
| return *ret_val; | |||
| @@ -208,7 +208,11 @@ class _MindSporeFunction: | |||
| if context.get_context("precompile_only"): | |||
| return None | |||
| return self._executor(args_list, phase) | |||
| new_inputs = [] | |||
| for i in args_list: | |||
| if isinstance(i, Tensor): | |||
| new_inputs.append(i) | |||
| return self._executor(tuple(new_inputs), phase) | |||
| def ms_function(fn=None, obj=None, input_signature=None): | |||
| @@ -18,7 +18,6 @@ | |||
| from ...composite import base | |||
| from ... import functional as F | |||
| zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf', True) | |||
| """ | |||
| `zeros_like_leaf` is a metafuncgraph object which will generate a tensor filled with one according to its input type | |||
| @@ -31,11 +30,13 @@ def _zeros_like_scala(x): | |||
| """Returns 0 which has the same dtype as x where x is a scalar.""" | |||
| return 0 | |||
| @zeros_like_leaf.register("Bool") | |||
| def _zeros_like_bool(x): | |||
| """Returns False if x is a bool.""" | |||
| return False | |||
| newenv = base.EnvInstance_() | |||
| @@ -100,6 +101,25 @@ def _zeros_like_abstract_error(x): | |||
| return x | |||
| @zeros_like_leaf.register("Dictionary") | |||
| def _zeros_like_dict(x): | |||
| """ | |||
| Derivation of a AbstractError. | |||
| Args: | |||
| x (dict): the input | |||
| Returns: | |||
| dict, keys are same as input's keys, and value are same as zeros_like of input'value. | |||
| """ | |||
| keys = x.keys() | |||
| values = x.values() | |||
| new_values = () | |||
| for ele in values: | |||
| new_values += (zeros_like_leaf(ele),) | |||
| return F.make_dict(keys, new_values) | |||
| # zeros_like is an object that will generate graph of zero_like operation for different type | |||
| zeros_like = base.HyperMap(zeros_like_leaf) | |||
| """`zeros_like` is an object that will generate graph of `zero_like` operation for different type.""" | |||
| @@ -0,0 +1,60 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test ms_function pass non_tensor inputs""" | |||
| import numpy as np | |||
| from mindspore import Tensor, ms_function, Parameter | |||
| from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True) | |||
| @ms_function | |||
| def compute(x, y, tuple_p, list_q, dict_w): | |||
| return x + y - tuple_p[0] + list_q[1] - dict_w["x"] | |||
| def test_scalar_compute(): | |||
| int_x = 1 | |||
| int_y = 2 | |||
| p = (3, 4) | |||
| q = [5, 6] | |||
| w = {"x": 7, "y": 8} | |||
| ret = compute(int_x, int_y, p, q, w) | |||
| assert ret == -1 | |||
| def test_tensor_compute(): | |||
| tensor_x = Tensor(np.ones((2, 3, 4), np.float32)) | |||
| tensor_y = Tensor(np.ones((2, 3, 4), np.float32) * 2) | |||
| p = (Tensor(np.ones((2, 3, 4), np.float32) * 3), Tensor(np.ones((2, 3, 4), np.float32) * 4)) | |||
| q = [Tensor(np.ones((2, 3, 4), np.float32) * 5), Tensor(np.ones((2, 3, 4), np.float32) * 6)] | |||
| w = {"x": Tensor(np.ones((2, 3, 4), np.float32) * 7), "y": Tensor(np.ones((2, 3, 4), np.float32) * 8)} | |||
| compute(tensor_x, tensor_y, p, q, w) | |||
| @ms_function | |||
| def tensor_reduce(tensor_x, axis, tensor_y): | |||
| reduce_sum = P.ReduceSum() | |||
| ret = reduce_sum(tensor_x, axis) + tensor_y | |||
| return ret | |||
| def test_tensor_reduce(): | |||
| tensor_x = Tensor(np.ones((2, 3, 4, 5), np.float32)) | |||
| axis = (0, 1) | |||
| tensor_y = Parameter(Tensor(np.ones((4, 5), np.float32) * 2)) | |||
| tensor_reduce(tensor_x, axis, tensor_y) | |||
| @@ -12,8 +12,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test outermost net pass scalar tuple list dict""" | |||
| import pytest | |||
| """ test outermost net pass non_tensor inputs""" | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| @@ -28,7 +27,7 @@ def test_outermost_net_pass_scalar_tuple_list_dict(): | |||
| class TestNet(nn.Cell): | |||
| def __init__(self): | |||
| super(TestNet, self).__init__() | |||
| self.support_non_tensor_inputs = True | |||
| self.support_non_tensor_inputs = False | |||
| def construct(self, tuple_a, z, list_m, w, s, dict_n): | |||
| return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"] | |||
| @@ -58,12 +57,5 @@ def test_outermost_net_pass_scalar_tuple_list_dict(): | |||
| forward_net(arg_t1, z, arg_l1, x, 6, args_d1) | |||
| grad_net = GradNet(forward_net) | |||
| with pytest.raises(TypeError) as err: | |||
| grad_net(arg_t0, z, arg_l0, w, 6, args_d0) | |||
| assert "For 'graph mode', the 0th arg" in str(err.value) | |||
| grad_net.support_non_tensor_inputs = True | |||
| with pytest.raises(ValueError) as err: | |||
| grad_net(arg_t0, z, arg_l0, w, 6, args_d0) | |||
| assert "Not support set 'support_non_tensor_inputs' to the 'True' for grad net, only support forward net." \ | |||
| in str(err.value) | |||
| grad_net(arg_t0, z, arg_l0, w, 6, args_d0) | |||
| grad_net(arg_t1, z, arg_l1, x, 6, args_d1) | |||
| @@ -20,9 +20,9 @@ import mindspore.ops.operations as P | |||
| from mindspore import Tensor, context | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import functional as F | |||
| from ...ut_filter import non_graph_engine | |||
| # pylint: disable=unused-argument | |||
| def setup_module(module): | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| @@ -86,24 +86,6 @@ def test_cast_grad(): | |||
| assert np.all(gout[0].asnumpy() == expect) | |||
| def test_scalar_cast_grad(): | |||
| """ test_scalar_cast_grad """ | |||
| input_x = 255.5 | |||
| input_t = ms.int8 | |||
| def fx_cast(x): | |||
| output = F.scalar_cast(x, input_t) | |||
| return output | |||
| @ms_function | |||
| def grad_fx_cast(input_x): | |||
| return grad(fx_cast)(input_x) | |||
| gfn = grad_fx_cast(input_x) | |||
| expect_dx = 1 | |||
| assert gfn == expect_dx | |||
| @non_graph_engine | |||
| def test_reshape_grad(): | |||
| """ test_reshape_grad """ | |||
| @@ -14,31 +14,28 @@ | |||
| # ============================================================================ | |||
| """ test_framstruct """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter, ParameterTuple | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from ..ut_filter import non_graph_engine | |||
| from ....mindspore_test_framework.utils.check_gradient import ( | |||
| ms_function, check_jacobian, Tensor, NNGradChecker, | |||
| OperationGradChecker, check_gradient, ScalarGradChecker) | |||
| OperationGradChecker, check_gradient) | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| def setup_module(module): | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| grad = C.GradOperation() | |||
| grad_all = C.GradOperation(get_all=True) | |||
| grad_by_list = C.GradOperation(get_by_list=True) | |||
| grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True) | |||
| @ms_function | |||
| @@ -79,9 +76,7 @@ def dynamic_make_tuple(x, lower, upper): | |||
| def test_dynamic_make_tuple(): | |||
| # Dynamicly recursively creating static type is invalid in mindspore, as mindspore is a static language. | |||
| with pytest.raises(RuntimeError): | |||
| dynamic_make_tuple(2, 1, 5) | |||
| assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2) | |||
| def test_make_tuple(): | |||
| @@ -273,15 +268,6 @@ def rec(x): | |||
| return rec(x - 1) | |||
| return x | |||
| @ms_function | |||
| def grad_rec(input_x): | |||
| return grad(rec)(input_x) | |||
| def test_grad_rec(): | |||
| """ test_grad_rec """ | |||
| res = grad_rec(3) | |||
| assert res == 1 | |||
| def test_me_rec(): | |||
| """ test_me_rec """ | |||
| @@ -303,13 +289,6 @@ def test_while2(): | |||
| assert res == 6 | |||
| def test_grad_while2(): | |||
| @ms_function | |||
| def df_t2_while(input_x, input_y): | |||
| return grad(t2_while)(input_x, input_y) | |||
| assert df_t2_while(2, 3) == 3 | |||
| def if_test(a, b): | |||
| """ if_test """ | |||
| if a > b: | |||
| @@ -327,24 +306,6 @@ def test_grad_if(): | |||
| assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0) | |||
| # While loop is not unrolled in forward and backward graphs. | |||
| def test_dont_unroll_while(): | |||
| def dont_unroll_while(x, y): | |||
| i = 2 | |||
| out = y - x | |||
| while i < 10: | |||
| out = mul(x, y) | |||
| i = i + 1 | |||
| return out | |||
| @ms_function() | |||
| def invoke_while(x, y): | |||
| return grad(dont_unroll_while)(x, y) | |||
| res = invoke_while(2, 3) | |||
| assert res == 3 | |||
| class ConvNet(nn.Cell): | |||
| def __init__(self): | |||
| super(ConvNet, self).__init__() | |||
| @@ -445,13 +406,6 @@ def test_factorial(): | |||
| assert res == 6 | |||
| def test_grad_factorial(): | |||
| @ms_function | |||
| def df_factorial(x): | |||
| return grad(factorial)(x) | |||
| assert df_factorial(3) == 11 | |||
| @ms_function | |||
| def factorial2(n): | |||
| """ factorial """ | |||
| @@ -523,17 +477,13 @@ def _for(x): | |||
| ret = ret * i | |||
| return ret | |||
| @ms_function | |||
| def grad_for(x): | |||
| """ grad_for """ | |||
| return grad_all(_for)(x) | |||
| def test_grad_for(): | |||
| """ test_grad_for """ | |||
| assert grad_for(5) == (60,) | |||
| @ms_function | |||
| def try_tail(x): | |||
| """ try_tail """ | |||
| @@ -675,15 +625,6 @@ def test_arithmetic_simplify_08(): | |||
| assert np.all(res.asnumpy() == expect) | |||
| def test_ScalarGradChecker(): | |||
| """ test_ScalarGradChecker """ | |||
| def scalar_f(x, y): | |||
| return x * y | |||
| check_gradient(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, sampling_times=1) | |||
| def test_GradCheckerPrimitive(): | |||
| """ test_GradCheckerPrimitive """ | |||
| matmul = P.MatMul() | |||
| @@ -737,15 +678,6 @@ def test_OperationGradChecker(): | |||
| input_selector=[1], sampling_times=2) | |||
| def test_ScalarJacobianChecker(): | |||
| """ test_ScalarJacobianChecker """ | |||
| def scalar_f(x, y): | |||
| return x * y | |||
| check_jacobian(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, input_selector=[0]) | |||
| def test_OperationJacobianChecker(): | |||
| """ test_OperationJacobianChecker """ | |||
| @@ -795,13 +727,6 @@ def multi_outputs(x, y): | |||
| return 2 * z, 2 * z | |||
| def test_grad_multi_outputs(): | |||
| @ms_function | |||
| def df_multi_outputs(x, y): | |||
| return grad_all_with_sens(multi_outputs)(x, y, (1, 1)) | |||
| assert df_multi_outputs(2, 3) == (4, 4) | |||
| @ms_function | |||
| def while_sp(x, y, z): | |||
| out = x | |||
| @@ -874,13 +799,6 @@ def grad_refactor_3(a): | |||
| return 3 * a | |||
| def test_grad_refactor_3(): | |||
| @ms_function | |||
| def df_refactor_3(x): | |||
| return grad_all(grad_refactor_3)(x) | |||
| assert df_refactor_3(3) == (3,) | |||
| def grad_refactor_4(a): | |||
| """ if_test """ | |||
| if a > 3: | |||
| @@ -899,13 +817,6 @@ def grad_refactor_5(a): | |||
| return a | |||
| def test_grad_refactor_5(): | |||
| @ms_function | |||
| def df_refactor_5(x): | |||
| return grad_all(grad_refactor_5)(x) | |||
| assert df_refactor_5(1) == (1,) | |||
| def grad_refactor_6(a, b): | |||
| """ if_test """ | |||
| if a > b: | |||
| @@ -925,13 +836,6 @@ def grad_refactor_while(x): | |||
| return rval | |||
| def test_grad_refactor_9(): | |||
| @ms_function | |||
| def df_refactor_while(input_x): | |||
| return grad_all(grad_refactor_while)(input_x) | |||
| assert df_refactor_while(3) == (6,) | |||
| def grad_refactor__while_1(x): | |||
| """ _while """ | |||
| ret = x * x | |||
| @@ -1009,13 +913,6 @@ def grad_refactor_14(a, b): | |||
| return inner1(b) + inner2(a) + inner3(a) | |||
| def test_grad_refactor_14(): | |||
| @ms_function | |||
| def df_refactor_14(x, y): | |||
| return grad_all(grad_refactor_14)(x, y) | |||
| assert df_refactor_14(2, 3) == (3, 9) | |||
| # pylint: disable=using-constant-test | |||
| class IfDeferInline(nn.Cell): | |||
| def __init__(self, mul_size): | |||
| @@ -1044,6 +941,8 @@ def test_dict_const(): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.res = {'1': 10} | |||
| def construct(self): | |||
| return self.res | |||
| Net()() | |||
| @@ -109,25 +109,3 @@ def first_derivative_if(x): | |||
| def second_derivative_if(x): | |||
| """ second_derivative_if """ | |||
| return grad(first_derivative_if)(x) | |||
| def test_high_order_grad_1(): | |||
| """ test_high_order_grad_1 """ | |||
| # 18 | |||
| assert third_derivative(2) == 18 | |||
| # 18 * y * y * y, 18 * x * x * x | |||
| assert third_derivative_dual(4, 5) == (2250, 1152) | |||
| # 18 * x | |||
| assert second_derivative_all(3) == 54 | |||
| def test_high_order_grad_2(): | |||
| """ test_high_order_grad_2 """ | |||
| # 2 | |||
| assert second_derivative_if(12) == 2 | |||
| def test_high_order_grad_3(): | |||
| """ test_high_order_grad_2 """ | |||
| # 6 * x | |||
| assert second_derivative_if(4) == 24 | |||
| @@ -325,7 +325,7 @@ def invoke_dataclass2(x, y): | |||
| def test_access_attr_error(): | |||
| """ test_access """ | |||
| with pytest.raises(AttributeError): | |||
| invoke_dataclass2(1, 2) | |||
| invoke_dataclass2(2, 1) | |||
| def myfunc(x): | |||