Merge pull request !3820 from zhangbuxue/support_dtype_and_shape_as_attr_in_graph_modetags/v0.7.0-beta
| @@ -28,7 +28,8 @@ from ...ops.composite.base import _append | |||
| __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like'] | |||
| trans = P.Transpose() | |||
| shape_ = P.Shape() | |||
| dtype_ = P.DType() | |||
| def transpose(x): | |||
| """Implementation of `transpose`.""" | |||
| @@ -93,7 +93,6 @@ inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("arra | |||
| inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape"); | |||
| inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map"); | |||
| inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce"); | |||
| inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape"); | |||
| inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | |||
| inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | |||
| inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "pipeline/jit/static_analysis/prim.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "abstract/utils.h" | |||
| #include "frontend/operator/cc_implementations.h" | |||
| #include "abstract/param_validator.h" | |||
| @@ -80,23 +79,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti | |||
| return std::make_shared<AbstractTuple>(elems); | |||
| } | |||
| AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tensor. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_LOG(DEBUG) << "InferImplShape:" << arg->ToString(); | |||
| AbstractBasePtrList values; | |||
| auto shp = arg->shape(); | |||
| for (int entry : shp->shape()) { | |||
| auto entry_v = MakeValue(entry); | |||
| values.push_back(std::make_shared<AbstractScalar>(entry_v, entry_v->type())); | |||
| } | |||
| return std::make_shared<AbstractTuple>(values); | |||
| } | |||
| AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tensor and a tuple. | |||
| @@ -985,6 +985,7 @@ void ClearResAtexit() { | |||
| abstract::ClearPrimEvaluatorMap(); | |||
| compile::ClearConvertCache(); | |||
| pipeline::GetMethodMap().clear(); | |||
| pipeline::GetAttrMap().clear(); | |||
| pipeline::ExecutorPy::ClearRes(); | |||
| pipeline::ReclaimOptimizer(); | |||
| pynative::PynativeExecutor::GetInstance()->ClearRes(); | |||
| @@ -17,23 +17,20 @@ | |||
| */ | |||
| #include "pipeline/jit/resource.h" | |||
| #include "pipeline/jit/pipeline.h" | |||
| #include "pipeline/jit/static_analysis/static_analysis.h" | |||
| #include "debug/draw.h" | |||
| #include "debug/trace.h" | |||
| #include "ir/dtype.h" | |||
| #include "pipeline/jit/parse/data_converter.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "ir/graph_utils.h" | |||
| #include "frontend/optimizer/ad/dfunctor.h" | |||
| #include "vm/segment_runner.h" | |||
| namespace mindspore { | |||
| // namespace to support opmap definition | |||
| namespace pipeline { | |||
| MethodMap &GetMethodMap() { | |||
| static MethodMap method_map = { | |||
| BuiltInTypeMap &GetMethodMap() { | |||
| static BuiltInTypeMap method_map = { | |||
| {kObjectTypeString, | |||
| { | |||
| {"__bool__", std::string("str_bool")} // C.str_bool | |||
| @@ -191,6 +188,15 @@ MethodMap &GetMethodMap() { | |||
| return method_map; | |||
| } | |||
| BuiltInTypeMap &GetAttrMap() { | |||
| static BuiltInTypeMap attr_map = {{kObjectTypeTensorType, | |||
| { | |||
| {"shape", std::string("shape_")}, // C.shape_ | |||
| {"dtype", std::string("dtype_")}, // C.dtype_ | |||
| }}}; | |||
| return attr_map; | |||
| } | |||
| Resource::Resource(const py::object &obj) | |||
| : engine_(std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager_)), | |||
| input_(obj), | |||
| @@ -218,31 +224,42 @@ Resource::~Resource() { | |||
| } | |||
| } | |||
| bool Resource::IsTypeInMethodMap(const TypeId &type) { | |||
| TypeId type_id = NormalizeTypeId(type); | |||
| const MethodMap &method_map = GetMethodMap(); | |||
| auto iter = method_map.find(static_cast<int>(type_id)); | |||
| if (iter != method_map.end()) { | |||
| return true; | |||
| Any GetMethodOrAttr(const string &name, const TypeId &type_id, const BuiltInTypeMap &method_map) { | |||
| auto type_method_map = method_map.find(static_cast<int>(type_id)); | |||
| if (type_method_map == method_map.end()) { | |||
| return Any(); | |||
| } | |||
| return false; | |||
| auto method = type_method_map->second.find(name); | |||
| if (method == type_method_map->second.end()) { | |||
| return Any(); | |||
| } | |||
| return method->second; | |||
| } | |||
| Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { | |||
| bool Resource::IsTypeInBuiltInMap(const TypeId &type) { | |||
| TypeId type_id = NormalizeTypeId(type); | |||
| const MethodMap &method_map = GetMethodMap(); | |||
| const BuiltInTypeMap &method_map = GetMethodMap(); | |||
| auto iter = method_map.find(static_cast<int>(type_id)); | |||
| if (iter == method_map.end()) { | |||
| MS_LOG(WARNING) << "Object type: " << type_id << " not in the method_map"; | |||
| return Any(); | |||
| const BuiltInTypeMap &attr_map = GetAttrMap(); | |||
| iter = attr_map.find(static_cast<int>(type_id)); | |||
| if (iter == attr_map.end()) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| auto iter_map = iter->second.find(name); | |||
| if (iter_map == iter->second.end()) { | |||
| MS_LOG(WARNING) << "Object type: " << type_id << " have no method: " << name; | |||
| return Any(); | |||
| } | |||
| return iter_map->second; | |||
| Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { | |||
| TypeId type_id = NormalizeTypeId(type); | |||
| const BuiltInTypeMap &method_map = GetMethodMap(); | |||
| return GetMethodOrAttr(name, type_id, method_map); | |||
| } | |||
| Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) { | |||
| TypeId type_id = NormalizeTypeId(type); | |||
| const BuiltInTypeMap &attr_map = GetAttrMap(); | |||
| return GetMethodOrAttr(name, type_id, attr_map); | |||
| } | |||
| void Resource::Clean() { | |||
| @@ -44,9 +44,11 @@ const char kOutput[] = "output"; | |||
| class InferenceResource; | |||
| using MethodMap = std::unordered_map<int, std::unordered_map<std::string, Any>>; | |||
| using BuiltInTypeMap = std::unordered_map<int, std::unordered_map<std::string, Any>>; | |||
| MethodMap &GetMethodMap(); | |||
| BuiltInTypeMap &GetMethodMap(); | |||
| BuiltInTypeMap &GetAttrMap(); | |||
| class ResourceBase { | |||
| public: | |||
| @@ -87,10 +89,12 @@ class Resource : public ResourceBase { | |||
| abstract::AnalysisEnginePtr engine() { return engine_; } | |||
| static bool IsTypeInMethodMap(const TypeId &type); | |||
| static bool IsTypeInBuiltInMap(const TypeId &type); | |||
| static Any GetMethodPtr(const TypeId &type, const std::string &name); | |||
| static Any GetAttrPtr(const TypeId &type, const std::string &name); | |||
| const py::object &input() const { return input_; } | |||
| FuncGraphPtr func_graph() const { return func_graph_; } | |||
| @@ -21,7 +21,6 @@ | |||
| #include <algorithm> | |||
| #include <limits> | |||
| #include <mutex> | |||
| #include <set> | |||
| #include <string> | |||
| #include <utility> | |||
| @@ -31,10 +30,8 @@ | |||
| #include "frontend/operator/prim_to_function.h" | |||
| #include "abstract/utils.h" | |||
| #include "utils/symbolic.h" | |||
| #include "./common.h" | |||
| #include "pipeline/jit/resource.h" | |||
| #include "pipeline/jit/parse/resolve.h" | |||
| #include "ir/tensor.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "pipeline/jit/parse/data_converter.h" | |||
| @@ -64,7 +61,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, | |||
| {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, | |||
| {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, | |||
| {prim::kPrimShape, {InferImplShape, true}}, | |||
| {prim::kPrimPack, {InferImplPack, true}}, | |||
| // Structure | |||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | |||
| @@ -634,7 +630,7 @@ EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveIm | |||
| } | |||
| const int kResolveCaseUserDefineClass = 1; | |||
| const int kResolveCaseBuildinTypeMethod = 2; | |||
| const int kResolveCaseBuiltInType = 2; | |||
| const int kResolveCaseFunction = 3; | |||
| int GetResolveCase(const TypePtr &data_type) { | |||
| MS_EXCEPTION_IF_NULL(data_type); | |||
| @@ -643,8 +639,8 @@ int GetResolveCase(const TypePtr &data_type) { | |||
| } | |||
| // try method map, if not in method map, the data_type should be External type. | |||
| if (pipeline::Resource::IsTypeInMethodMap(data_type->type_id())) { | |||
| return kResolveCaseBuildinTypeMethod; | |||
| if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) { | |||
| return kResolveCaseBuiltInType; | |||
| } | |||
| return kResolveCaseFunction; | |||
| @@ -674,8 +670,10 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun | |||
| manager->AddFuncGraph(func_graph); | |||
| } | |||
| EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, | |||
| const AnfNodeConfigPtr &old_conf) { | |||
| enum REQUIRE_TYPE { ATTR, METHOD }; | |||
| EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf, | |||
| REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) { | |||
| MS_EXCEPTION_IF_NULL(old_conf); | |||
| AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); | |||
| @@ -701,6 +699,9 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_ | |||
| MS_EXCEPTION_IF_NULL(old_conf); | |||
| FuncGraphPtr func_graph = old_conf->node()->func_graph(); | |||
| CNodePtr new_cnode = func_graph->NewCNode(input); | |||
| if (require_type == REQUIRE_TYPE::ATTR) { | |||
| new_cnode = func_graph->NewCNode({new_cnode}); | |||
| } | |||
| AnalysisEnginePtr eng = old_conf->engine(); | |||
| AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context()); | |||
| return eng->ForwardConfig(old_conf, fn_conf); | |||
| @@ -781,9 +782,9 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng | |||
| return StaticGetterInferred(converted_v, data_conf, out_conf); | |||
| } | |||
| EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, | |||
| const TypePtr &data_type, const ConfigPtr &data_conf, | |||
| const AnfNodeConfigPtr &out_conf) { | |||
| EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, | |||
| const TypePtr &data_type, const ConfigPtr &data_conf, | |||
| const AnfNodeConfigPtr &out_conf) { | |||
| MS_EXCEPTION_IF_NULL(item_v); | |||
| MS_EXCEPTION_IF_NULL(data_type); | |||
| // The method maybe a Primitive or Composite | |||
| @@ -792,22 +793,29 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &eng | |||
| } | |||
| std::string item_name = item_v->cast<StringImmPtr>()->value(); | |||
| Any method = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name); | |||
| if (method.empty()) { | |||
| MS_LOG(EXCEPTION) << "Object type: " << data_type->ToString() << " has no method: " << item_name; | |||
| REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD; | |||
| Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name); | |||
| if (require.empty()) { | |||
| require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name); | |||
| if (require.empty()) { | |||
| MS_LOG(EXCEPTION) << "The object of type: " << data_type->ToString() << " has no method or attr: " << item_name; | |||
| } | |||
| require_type = REQUIRE_TYPE::ATTR; | |||
| } | |||
| ValuePtr converted_v = nullptr; | |||
| if (method.is<std::string>()) { | |||
| if (require.is<std::string>()) { | |||
| // composite registered in standard_method_map go to this branch | |||
| converted_v = prim::GetPythonOps(method.cast<std::string>()); | |||
| AddToManager(engine, converted_v->cast<FuncGraphPtr>()); | |||
| } else if (method.is<PrimitivePtr>()) { | |||
| converted_v = method.cast<PrimitivePtr>(); | |||
| converted_v = prim::GetPythonOps(require.cast<std::string>()); | |||
| if (!converted_v->isa<Primitive>()) { | |||
| AddToManager(engine, converted_v->cast<FuncGraphPtr>()); | |||
| } | |||
| } else if (require.is<PrimitivePtr>()) { | |||
| converted_v = require.cast<PrimitivePtr>(); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from method map, but got " << method.ToString(); | |||
| MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString(); | |||
| } | |||
| return StaticGetterInferred(converted_v, data_conf, out_conf); | |||
| return StaticGetterInferred(converted_v, data_conf, out_conf, require_type); | |||
| } | |||
| EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||
| @@ -831,8 +839,8 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt | |||
| int case_v = GetResolveCase(data_type); | |||
| if (case_v == kResolveCaseUserDefineClass) { | |||
| return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf); | |||
| } else if (case_v == kResolveCaseBuildinTypeMethod) { | |||
| return GetEvaluatedValueForBuiltinTypeMethod(engine, item_value, data_type, data_conf, out_conf); | |||
| } else if (case_v == kResolveCaseBuiltInType) { | |||
| return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf); | |||
| } else { | |||
| return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf); | |||
| } | |||
| @@ -218,10 +218,6 @@ AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const P | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -246,8 +242,6 @@ AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const Primitiv | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| @@ -22,20 +22,21 @@ import copy | |||
| import functools | |||
| import itertools | |||
| import numbers | |||
| import numpy as np | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| from ...common.tensor import Tensor | |||
| from ...common.parameter import Parameter | |||
| from ..operations.math_ops import _infer_shape_reduce | |||
| from .._utils import get_concat_offset | |||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op | |||
| from ..._c_expression import signature_rw as sig_rw | |||
| from ..._c_expression import signature_kind as sig_kind | |||
| from ..operations.math_ops import _infer_shape_reduce | |||
| from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op | |||
| from ..._c_expression import signature_dtype as sig_dtype | |||
| from ..._c_expression import signature_kind as sig_kind | |||
| from ..._c_expression import signature_rw as sig_rw | |||
| from ..._c_expression import typing | |||
| from ..._checkparam import Rel | |||
| from ..._checkparam import Validator as validator | |||
| from ...common import dtype as mstype | |||
| from ...common.parameter import Parameter | |||
| from ...common.tensor import Tensor | |||
| class _ScatterOp(PrimitiveWithInfer): | |||
| @@ -415,7 +416,7 @@ class Reshape(PrimitiveWithInfer): | |||
| return out | |||
| class Shape(Primitive): | |||
| class Shape(PrimitiveWithInfer): | |||
| """ | |||
| Returns the shape of input tensor. | |||
| @@ -436,6 +437,13 @@ class Shape(Primitive): | |||
| def __init__(self): | |||
| """init Shape""" | |||
| def __infer__(self, x): | |||
| validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) | |||
| out = {'shape': (), | |||
| 'dtype': mstype.tuple_, | |||
| 'value': tuple(x['shape'])} | |||
| return out | |||
| class Squeeze(PrimitiveWithInfer): | |||
| """ | |||
| @@ -267,11 +267,6 @@ TEST_F(TestOps, BroadCastShapeTest) { | |||
| ASSERT_EQ(prim->name(), kPrimBroadcastShape->name()); | |||
| } | |||
| TEST_F(TestOps, ShapeTest) { | |||
| auto prim = std::make_shared<Primitive>("Shape"); | |||
| ASSERT_EQ(prim->name(), kPrimShape->name()); | |||
| } | |||
| TEST_F(TestOps, ArrayMapTest) { | |||
| auto prim = std::make_shared<Primitive>("array_map"); | |||
| ASSERT_EQ(prim->name(), kPrimArrayMap->name()); | |||
| @@ -36,23 +36,23 @@ class TestResource : public UT::Common { | |||
| void TearDown() {} | |||
| }; | |||
| TEST_F(TestResource, test_standard_method_map) { | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt8)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt16)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt32)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt64)); | |||
| TEST_F(TestResource, test_built_in_type_map) { | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt8)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt16)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt32)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt64)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat16)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat32)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat64)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat16)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat32)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat64)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeBool)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeUInt)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeTuple)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeList)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeTensorType)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeBool)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeUInt)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeTuple)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeList)); | |||
| ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeTensorType)); | |||
| MethodMap& map = GetMethodMap(); | |||
| for (auto& iter : map) { | |||
| @@ -467,24 +467,6 @@ TEST_F(TestPrim, test_env_add) { | |||
| ASSERT_TRUE(*res == *exp); | |||
| } | |||
| TEST_F(TestPrim, test_shape) { | |||
| PrimitivePtr shap = std::make_shared<Primitive>("Shape"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(shap, 1); | |||
| auto a = UTPrimUtils::ArrayFloat64Of({2, 3}); | |||
| AbstractBasePtrList args_spec_list = {a}; | |||
| AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract()); | |||
| auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value(); | |||
| std::vector<ValuePtr> element_list = {MakeValue(2), MakeValue(3)}; | |||
| ASSERT_TRUE(ret.size() == element_list.size()); | |||
| for (int i = 0; i < element_list.size(); i++) { | |||
| ASSERT_TRUE(*ret[i] == *element_list[i]); | |||
| } | |||
| } | |||
| TEST_F(TestPrim, test_relu) { | |||
| PrimitivePtr relu = prim::kPrimRelu; | |||
| relu->AddAttr("T", MakeValue(static_cast<int>(kNumberTypeFloat64))); | |||
| @@ -0,0 +1,96 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test dtype and shape as attr""" | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def test_dtype_and_shape_as_attr(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, x): | |||
| shape = x.shape | |||
| dtype = x.dtype | |||
| return shape, dtype | |||
| net = Net() | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| ret = net(x) | |||
| assert ret == ((1, 2, 3), mstype.int32) | |||
| def test_dtype_and_shape_as_attr_to_new_tensor(): | |||
| class Net(nn.Cell): | |||
| def __init__(self, value): | |||
| super(Net, self).__init__() | |||
| self.fill = P.Fill() | |||
| self.value = value | |||
| def construct(self, x): | |||
| dtype = x.dtype | |||
| shape = x.shape | |||
| y = self.fill(dtype, shape, self.value) | |||
| return y | |||
| net = Net(2.2) | |||
| x = Tensor(np.ones([1, 2, 3], np.float32)) | |||
| ret = net(x) | |||
| assert (ret.asnumpy() == (np.zeros([1, 2, 3], np.float32) + 2.2)).all() | |||
| def test_type_not_have_the_attr(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, x): | |||
| shape = x.shapes | |||
| return shape | |||
| net = Net() | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| with pytest.raises(RuntimeError) as ex: | |||
| net(x) | |||
| assert "The object of type: Tensor[Int32] has no method or attr: shapes" in str(ex.value) | |||
| def test_type_not_have_the_method(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, x): | |||
| shape = x.dtypes() | |||
| return shape | |||
| net = Net() | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| with pytest.raises(RuntimeError) as ex: | |||
| net(x) | |||
| assert "The object of type: Tensor[Int32] has no method or attr: dtypes" in str(ex.value) | |||
| @@ -20,7 +20,7 @@ import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| class FatherNet(nn.Cell): | |||
| @@ -92,7 +92,6 @@ class Net(nn.Cell): | |||
| def test_single_super(): | |||
| single_net = SingleSubNet(2, 3) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| single_net(x, y) | |||
| @@ -100,7 +99,6 @@ def test_single_super(): | |||
| def test_mul_super(): | |||
| mul_net = MulSubNet(2, 3, 4) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| mul_net(x, y) | |||
| @@ -108,7 +106,6 @@ def test_mul_super(): | |||
| def test_super_cell(): | |||
| net = Net(2) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| with pytest.raises(RuntimeError) as er: | |||
| @@ -142,7 +139,6 @@ def test_single_super_in(): | |||
| return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z | |||
| single_net_in = SingleSubNetIN(2, 3) | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| single_net_in(x, y) | |||