| @@ -34,6 +34,7 @@ | |||||
| #include "pybind_api/api_register.h" | #include "pybind_api/api_register.h" | ||||
| #include "ir/signature.h" | #include "ir/signature.h" | ||||
| #include "debug/trace.h" | #include "debug/trace.h" | ||||
| #include "utils/ms_context.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support composite operators definition | // namespace to support composite operators definition | ||||
| @@ -403,7 +404,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr & | |||||
| if (tail_type_ == kGradFirst) { | if (tail_type_ == kGradFirst) { | ||||
| if (sequeue->size() > 1 && (*sequeue)[1] != nullptr && | if (sequeue->size() > 1 && (*sequeue)[1] != nullptr && | ||||
| ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() || | ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() || | ||||
| ((*sequeue)[1]->BuildType() != nullptr && (*sequeue)[1]->BuildType()->isa<Number>()))) { | |||||
| (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr && | |||||
| (*sequeue)[1]->BuildType()->isa<Number>()))) { | |||||
| ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))})); | ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))})); | ||||
| } else { | } else { | ||||
| ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{}))); | ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{}))); | ||||
| @@ -416,7 +418,8 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr & | |||||
| if (tail_type_ == kGradAll) { | if (tail_type_ == kGradAll) { | ||||
| MS_EXCEPTION_IF_NULL((*sequeue)[i]); | MS_EXCEPTION_IF_NULL((*sequeue)[i]); | ||||
| if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() || | if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() || | ||||
| ((*sequeue)[i]->BuildType() != nullptr && (*sequeue)[i]->BuildType()->isa<Number>())) { | |||||
| (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr && | |||||
| (*sequeue)[i]->BuildType()->isa<Number>())) { | |||||
| elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); | elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -490,7 +490,8 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) { | |||||
| } | } | ||||
| AbstractBasePtr par_abs = param_node->abstract(); | AbstractBasePtr par_abs = param_node->abstract(); | ||||
| if (par_abs->isa<abstract::AbstractUndetermined>() || | if (par_abs->isa<abstract::AbstractUndetermined>() || | ||||
| (par_abs->BuildType() != nullptr && par_abs->BuildType()->isa<Number>())) { | |||||
| (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr && | |||||
| par_abs->BuildType()->isa<Number>())) { | |||||
| new_paras.push_back(param_node); | new_paras.push_back(param_node); | ||||
| } | } | ||||
| } | } | ||||
| @@ -98,7 +98,8 @@ std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) | |||||
| AbstractBasePtr ArgsToAbstract(const ValuePtr &value) { | AbstractBasePtr ArgsToAbstract(const ValuePtr &value) { | ||||
| MS_EXCEPTION_IF_NULL(value); | MS_EXCEPTION_IF_NULL(value); | ||||
| bool broaden = value->isa<MetaTensor>() || value->isa<Scalar>(); | |||||
| bool broaden = value->isa<MetaTensor>() || | |||||
| (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>()); | |||||
| return abstract::FromValue(value, broaden); | return abstract::FromValue(value, broaden); | ||||
| } | } | ||||
| @@ -95,7 +95,8 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) { | |||||
| .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE) | .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE) | ||||
| .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID) | .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID) | ||||
| .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH) | .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH) | ||||
| .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH); | |||||
| .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH) | |||||
| .value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR); | |||||
| (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext") | (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext") | ||||
| .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") | .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") | ||||
| .def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.") | .def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.") | ||||
| @@ -210,7 +210,9 @@ class _MindSporeFunction: | |||||
| return None | return None | ||||
| new_inputs = [] | new_inputs = [] | ||||
| for i in args_list: | for i in args_list: | ||||
| if isinstance(i, (Tensor, int, float)): | |||||
| if isinstance(i, Tensor): | |||||
| new_inputs.append(i) | |||||
| elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)): | |||||
| new_inputs.append(i) | new_inputs.append(i) | ||||
| return self._executor(tuple(new_inputs), phase) | return self._executor(tuple(new_inputs), phase) | ||||
| @@ -533,6 +533,7 @@ def set_context(**kwargs): | |||||
| save_graphs variable_memory_max_size | save_graphs variable_memory_max_size | ||||
| save_graphs_path | save_graphs_path | ||||
| env_config_path | env_config_path | ||||
| grad_for_scalar | |||||
| =========================== =========================== ================= | =========================== =========================== ================= | ||||
| Args: | Args: | ||||
| @@ -602,6 +603,7 @@ def set_context(**kwargs): | |||||
| enable_sparse (bool): Whether to enable sparsity feature. Default: False. | enable_sparse (bool): Whether to enable sparsity feature. Default: False. | ||||
| max_call_depth (int): Specify the maximum depth of function call. Default: 1000. | max_call_depth (int): Specify the maximum depth of function call. Default: 1000. | ||||
| env_config_path (str): Config path for DFX. | env_config_path (str): Config path for DFX. | ||||
| grad_for_scalar (bool): Whether to get gradient for scalar. Default: False. | |||||
| Raises: | Raises: | ||||
| ValueError: If input key is not an attribute in context. | ValueError: If input key is not an attribute in context. | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "utils/symbolic.h" | #include "utils/symbolic.h" | ||||
| #include "abstract/utils.h" | #include "abstract/utils.h" | ||||
| #include "utils/ms_context.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace abstract { | namespace abstract { | ||||
| @@ -88,7 +89,13 @@ std::string AbstractBase::ToString() const { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); } | |||||
| AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { | |||||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) { | |||||
| return AbstractBase::Broaden(config); | |||||
| } else { | |||||
| return Clone(); | |||||
| } | |||||
| } | |||||
| AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { | AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { | ||||
| MS_EXCEPTION_IF_NULL(other); | MS_EXCEPTION_IF_NULL(other); | ||||
| @@ -171,6 +171,12 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p | |||||
| return args_spec_list[0]; | return args_spec_list[0]; | ||||
| } | } | ||||
| auto depends = args_spec_list[0]->Broaden(); | auto depends = args_spec_list[0]->Broaden(); | ||||
| if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) { | |||||
| // For scalar, need to set value to kAnyValue, because broaden scalar will not change the value. | |||||
| if (depends->isa<AbstractScalar>()) { | |||||
| depends->set_value(kAnyValue); | |||||
| } | |||||
| } | |||||
| return depends; | return depends; | ||||
| } | } | ||||
| @@ -74,6 +74,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { | |||||
| set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false); | set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false); | ||||
| set_param<bool>(MS_CTX_ENABLE_SPARSE, false); | set_param<bool>(MS_CTX_ENABLE_SPARSE, false); | ||||
| set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false); | set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false); | ||||
| set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false); | |||||
| backend_policy_ = policy_map_[policy]; | backend_policy_ = policy_map_[policy]; | ||||
| } | } | ||||
| @@ -76,6 +76,7 @@ enum MsCtxParam : unsigned { | |||||
| MS_CTX_SAVE_GRAPHS_FLAG, | MS_CTX_SAVE_GRAPHS_FLAG, | ||||
| MS_CTX_ENABLE_PARALLEL_SPLIT, | MS_CTX_ENABLE_PARALLEL_SPLIT, | ||||
| MS_CTX_ENABLE_INFER_OPT, | MS_CTX_ENABLE_INFER_OPT, | ||||
| MS_CTX_GRAD_FOR_SCALAR, | |||||
| MS_CTX_TYPE_BOOL_END, | MS_CTX_TYPE_BOOL_END, | ||||
| // parameter of type int | // parameter of type int | ||||
| @@ -609,7 +609,9 @@ class Cell(Cell_): | |||||
| new_inputs = [] | new_inputs = [] | ||||
| for i in inputs: | for i in inputs: | ||||
| if isinstance(i, (Tensor, int, float)): | |||||
| if isinstance(i, Tensor): | |||||
| new_inputs.append(i) | |||||
| elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)): | |||||
| new_inputs.append(i) | new_inputs.append(i) | ||||
| if self._auto_parallel_mode: | if self._auto_parallel_mode: | ||||
| @@ -32,26 +32,18 @@ TEST_F(TestUtils, test_join) { | |||||
| AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false); | AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false); | ||||
| AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false); | AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false); | ||||
| AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true); | AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true); | ||||
| abs_s_anything->set_value(kAnyValue); | |||||
| AbstractBasePtr res_s1 = abs_s1->Join(abs_s2); | AbstractBasePtr res_s1 = abs_s1->Join(abs_s2); | ||||
| ASSERT_EQ(*res_s1, *abs_s_anything); | ASSERT_EQ(*res_s1, *abs_s_anything); | ||||
| // AbstractTuple join; | |||||
| std::vector<int64_t> list1 = {1, 2, 3, 4, 5}; | |||||
| std::vector<int64_t> list2 = {5, 4, 3, 2, 1}; | |||||
| AbstractBasePtr abs_t1 = FromValue(list1, true); | |||||
| AbstractBasePtr abs_t2 = FromValue(list2, true); | |||||
| AbstractBasePtr res_t1 = abs_t1->Join(abs_t2); | |||||
| ASSERT_EQ(res_t1, abs_t1); | |||||
| abs_s1 = FromValue(static_cast<int64_t>(1), false); | abs_s1 = FromValue(static_cast<int64_t>(1), false); | ||||
| AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything})); | AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything})); | ||||
| AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything})); | AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything})); | ||||
| AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything})); | AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything})); | ||||
| res_t1 = t1->Join(t2); | |||||
| AbstractBasePtr res_t1 = t1->Join(t2); | |||||
| ASSERT_EQ(res_t1, t1); | ASSERT_EQ(res_t1, t1); | ||||
| res_t1 = t1->Join(t3); | res_t1 = t1->Join(t3); | ||||
| @@ -111,8 +111,11 @@ TEST_F(TestOptLib, test_inline) { | |||||
| // add infer and renormalize | // add infer and renormalize | ||||
| std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>(); | std::shared_ptr<mindspore::pipeline::Resource> res = std::make_shared<mindspore::pipeline::Resource>(); | ||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast<int64_t>(1), true); | |||||
| AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast<int64_t>(2), true); | |||||
| tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3}); | |||||
| tensor::TensorPtr y_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{2, 3}); | |||||
| AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true); | |||||
| AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true); | |||||
| args_spec_list.push_back(abstract_v1); | args_spec_list.push_back(abstract_v1); | ||||
| args_spec_list.push_back(abstract_v2); | args_spec_list.push_back(abstract_v2); | ||||
| AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list); | AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list); | ||||
| @@ -74,20 +74,17 @@ TEST_F(TestData, test_build_value) { | |||||
| AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false); | AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false); | ||||
| AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2})); | AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2})); | ||||
| ValuePtr func_tuple_built = abs_func_tuple->BuildValue(); | ValuePtr func_tuple_built = abs_func_tuple->BuildValue(); | ||||
| ASSERT_EQ(*func_tuple_built, | |||||
| ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd})); | |||||
| ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd})); | |||||
| // BuildValue(List(AbstractFunction)) should return kAnyValue; | // BuildValue(List(AbstractFunction)) should return kAnyValue; | ||||
| AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2})); | AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2})); | ||||
| ValuePtr func_list_built = abs_func_list->BuildValue(); | ValuePtr func_list_built = abs_func_list->BuildValue(); | ||||
| ASSERT_EQ(*func_list_built, | |||||
| ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd})); | |||||
| ASSERT_EQ(*func_list_built, ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd})); | |||||
| // BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue | // BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue | ||||
| abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2})); | abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2})); | ||||
| func_tuple_built = abs_func_tuple->BuildValue(); | func_tuple_built = abs_func_tuple->BuildValue(); | ||||
| ASSERT_EQ(*func_tuple_built, | |||||
| ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd})); | |||||
| ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd})); | |||||
| } | } | ||||
| TEST_F(TestData, test_build_type) { | TEST_F(TestData, test_build_type) { | ||||
| @@ -129,7 +126,7 @@ TEST_F(TestData, test_build_shape) { | |||||
| AbstractBasePtr abstract_tup = FromValue(vec, true); | AbstractBasePtr abstract_tup = FromValue(vec, true); | ||||
| std::shared_ptr<TupleShape> shape_tuple = dyn_cast<TupleShape>(abstract_tup->BuildShape()); | std::shared_ptr<TupleShape> shape_tuple = dyn_cast<TupleShape>(abstract_tup->BuildShape()); | ||||
| ASSERT_TRUE(shape_tuple); | ASSERT_TRUE(shape_tuple); | ||||
| const std::vector<BaseShapePtr>& ptr_vec = shape_tuple->shape(); | |||||
| const std::vector<BaseShapePtr> &ptr_vec = shape_tuple->shape(); | |||||
| ASSERT_EQ(ptr_vec.size(), 2); | ASSERT_EQ(ptr_vec.size(), 2); | ||||
| ShapePtr shape1 = dyn_cast<Shape>(ptr_vec[0]); | ShapePtr shape1 = dyn_cast<Shape>(ptr_vec[0]); | ||||
| @@ -148,14 +145,14 @@ TEST_F(TestData, test_clone) { | |||||
| ASSERT_TRUE(s1->GetValueTrack() == s2->GetValueTrack()); | ASSERT_TRUE(s1->GetValueTrack() == s2->GetValueTrack()); | ||||
| ASSERT_TRUE(*s1->GetShapeTrack() == *s2->GetShapeTrack()); | ASSERT_TRUE(*s1->GetShapeTrack() == *s2->GetShapeTrack()); | ||||
| AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), | |||||
| AnalysisContext::DummyContext()); | |||||
| AbstractFunctionPtr f1 = | |||||
| std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext()); | |||||
| AbstractBasePtr f2 = f1->Clone(); | AbstractBasePtr f2 = f1->Clone(); | ||||
| ASSERT_TRUE(*f2 == *f1); | ASSERT_TRUE(*f2 == *f1); | ||||
| AbstractList l1 = AbstractList({s1, s2}); | AbstractList l1 = AbstractList({s1, s2}); | ||||
| AbstractBasePtr l2 = l1.Clone(); | AbstractBasePtr l2 = l1.Clone(); | ||||
| AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get()); | |||||
| AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get()); | |||||
| ASSERT_TRUE(l2_cast != nullptr); | ASSERT_TRUE(l2_cast != nullptr); | ||||
| ASSERT_TRUE(l2_cast->GetValueTrack() == l1.GetValueTrack()); | ASSERT_TRUE(l2_cast->GetValueTrack() == l1.GetValueTrack()); | ||||
| @@ -184,19 +181,19 @@ TEST_F(TestData, test_broaden) { | |||||
| AbstractBasePtr s2 = s1->Broaden(); | AbstractBasePtr s2 = s1->Broaden(); | ||||
| ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack()); | ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack()); | ||||
| ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1)); | ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1)); | ||||
| ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>()); | |||||
| ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>()); | |||||
| AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), | |||||
| AnalysisContext::DummyContext()); | |||||
| AbstractFunctionPtr f1 = | |||||
| std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext()); | |||||
| AbstractBasePtr f2 = f1->Broaden(); | AbstractBasePtr f2 = f1->Broaden(); | ||||
| ASSERT_TRUE(f2 == f1); | ASSERT_TRUE(f2 == f1); | ||||
| AbstractList l1 = AbstractList({s1, s2}); | AbstractList l1 = AbstractList({s1, s2}); | ||||
| AbstractBasePtr l2 = l1.Broaden(); | AbstractBasePtr l2 = l1.Broaden(); | ||||
| AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get()); | |||||
| AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get()); | |||||
| ASSERT_TRUE(l2_cast != nullptr); | ASSERT_TRUE(l2_cast != nullptr); | ||||
| AbstractBasePtr csr = AbstractJoin(l2_cast->elements()); | AbstractBasePtr csr = AbstractJoin(l2_cast->elements()); | ||||
| ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>()); | |||||
| ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>()); | |||||
| } | } | ||||
| } // namespace abstract | } // namespace abstract | ||||
| @@ -14,7 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ test_framstruct """ | """ test_framstruct """ | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context | from mindspore import context | ||||
| @@ -76,9 +75,7 @@ def dynamic_make_tuple(x, lower, upper): | |||||
| def test_dynamic_make_tuple(): | def test_dynamic_make_tuple(): | ||||
| # Dynamically recursively creating static type is invalid in mindspore, as mindspore is a static language. | |||||
| with pytest.raises(RuntimeError): | |||||
| dynamic_make_tuple(2, 1, 5) | |||||
| assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2) | |||||
| def test_make_tuple(): | def test_make_tuple(): | ||||