| @@ -17,7 +17,7 @@ | |||||
| """Resources for ast tree parse.""" | """Resources for ast tree parse.""" | ||||
| import ast | import ast | ||||
| import math | import math | ||||
| from mindspore import IndexedSlices, SparseTensor | |||||
| from mindspore import RowTensor, SparseTensor | |||||
| from mindspore.ops.composite import multitype_ops | from mindspore.ops.composite import multitype_ops | ||||
| from mindspore.ops import functional as F, composite as C | from mindspore.ops import functional as F, composite as C | ||||
| from . import standard_method as M | from . import standard_method as M | ||||
| @@ -140,6 +140,6 @@ convert_object_map = { | |||||
| math.tan: NO_IMPLEMENT, | math.tan: NO_IMPLEMENT, | ||||
| # user defined | # user defined | ||||
| IndexedSlices: F.make_indexed_slices, | |||||
| RowTensor: F.make_row_tensor, | |||||
| SparseTensor: F.make_sparse_tensor, | SparseTensor: F.make_sparse_tensor, | ||||
| } | } | ||||
| @@ -120,7 +120,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s | |||||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); | type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); | ||||
| } | } | ||||
| } | } | ||||
| } else if (type->isa<IndexedSlicesType>()) { | |||||
| } else if (type->isa<RowTensorType>()) { | |||||
| // Do Nothing | // Do Nothing | ||||
| } else if (type->isa<UndeterminedType>()) { | } else if (type->isa<UndeterminedType>()) { | ||||
| // Do Nothing | // Do Nothing | ||||
| @@ -174,12 +174,11 @@ inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_Virtua | |||||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | ||||
| inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | ||||
| // IndexedSlices | |||||
| inline const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices"); | |||||
| inline const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues"); | |||||
| inline const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices"); | |||||
| inline const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape"); | |||||
| inline const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices"); | |||||
| // RowTensor | |||||
| inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor"); | |||||
| inline const PrimitivePtr kPrimRowTensorGetValues = std::make_shared<Primitive>("RowTensorGetValues"); | |||||
| inline const PrimitivePtr kPrimRowTensorGetIndices = std::make_shared<Primitive>("RowTensorGetIndices"); | |||||
| inline const PrimitivePtr kPrimRowTensorGetDenseShape = std::make_shared<Primitive>("RowTensorGetDenseShape"); | |||||
| // SparseTensor | // SparseTensor | ||||
| inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor"); | inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor"); | ||||
| @@ -340,8 +340,8 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv | |||||
| return std::make_shared<AbstractScalar>(kAnyValue, kBool); | return std::make_shared<AbstractScalar>(kAnyValue, kBool); | ||||
| } | } | ||||
| AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: two tensors and a tuple. | // Inputs: two tensors and a tuple. | ||||
| const std::string op_name = primitive->name(); | const std::string op_name = primitive->name(); | ||||
| CheckArgsSize(op_name, args_spec_list, 3); | CheckArgsSize(op_name, args_spec_list, 3); | ||||
| @@ -393,41 +393,41 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim | |||||
| << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i]; | << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i]; | ||||
| } | } | ||||
| } | } | ||||
| auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec); | |||||
| auto ret = std::make_shared<AbstractRowTensor>(values->element()->BuildType(), dense_shape_vec); | |||||
| ret->set_indices(indices); | ret->set_indices(indices); | ||||
| ret->set_values(values); | ret->set_values(values); | ||||
| ret->set_dense_shape(dense_shape); | ret->set_dense_shape(dense_shape); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: two tensors and a tuple. | // Inputs: two tensors and a tuple. | ||||
| const std::string op_name = primitive->name(); | const std::string op_name = primitive->name(); | ||||
| CheckArgsSize(op_name, args_spec_list, 1); | CheckArgsSize(op_name, args_spec_list, 1); | ||||
| auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(indexed_slices->values()); | |||||
| return indexed_slices->values(); | |||||
| auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(row_tensor->values()); | |||||
| return row_tensor->values(); | |||||
| } | } | ||||
| AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: two tensors and a tuple. | // Inputs: two tensors and a tuple. | ||||
| const std::string op_name = primitive->name(); | const std::string op_name = primitive->name(); | ||||
| CheckArgsSize(op_name, args_spec_list, 1); | CheckArgsSize(op_name, args_spec_list, 1); | ||||
| auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(indexed_slices->indices()); | |||||
| return indexed_slices->indices(); | |||||
| auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(row_tensor->indices()); | |||||
| return row_tensor->indices(); | |||||
| } | } | ||||
| AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: two tensors and a tuple. | // Inputs: two tensors and a tuple. | ||||
| const std::string op_name = primitive->name(); | const std::string op_name = primitive->name(); | ||||
| CheckArgsSize(op_name, args_spec_list, 1); | CheckArgsSize(op_name, args_spec_list, 1); | ||||
| auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape()); | |||||
| return indexed_slices->dense_shape(); | |||||
| auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(row_tensor->dense_shape()); | |||||
| return row_tensor->dense_shape(); | |||||
| } | } | ||||
| AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -32,9 +32,9 @@ namespace opt { | |||||
| using mindspore::abstract::AbstractAttribute; | using mindspore::abstract::AbstractAttribute; | ||||
| using mindspore::abstract::AbstractClass; | using mindspore::abstract::AbstractClass; | ||||
| using mindspore::abstract::AbstractDictionary; | using mindspore::abstract::AbstractDictionary; | ||||
| using mindspore::abstract::AbstractIndexedSlices; | |||||
| using mindspore::abstract::AbstractJTagged; | using mindspore::abstract::AbstractJTagged; | ||||
| using mindspore::abstract::AbstractList; | using mindspore::abstract::AbstractList; | ||||
| using mindspore::abstract::AbstractRowTensor; | |||||
| using mindspore::abstract::AbstractScalar; | using mindspore::abstract::AbstractScalar; | ||||
| using mindspore::abstract::AbstractSparseTensor; | using mindspore::abstract::AbstractSparseTensor; | ||||
| using mindspore::abstract::AbstractTuple; | using mindspore::abstract::AbstractTuple; | ||||
| @@ -81,10 +81,10 @@ static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) { | |||||
| return std::make_shared<AbstractTuple>(abstract_list); | return std::make_shared<AbstractTuple>(abstract_list); | ||||
| } | } | ||||
| if (t->isa<AbstractIndexedSlices>()) { | |||||
| auto abs_indexed_slices = dyn_cast<AbstractIndexedSlices>(t); | |||||
| std::vector<AbstractBasePtr> abstract_list{abs_indexed_slices->indices(), abs_indexed_slices->values(), | |||||
| abs_indexed_slices->dense_shape()}; | |||||
| if (t->isa<AbstractRowTensor>()) { | |||||
| auto abs_row_tensor = dyn_cast<AbstractRowTensor>(t); | |||||
| std::vector<AbstractBasePtr> abstract_list{abs_row_tensor->indices(), abs_row_tensor->values(), | |||||
| abs_row_tensor->dense_shape()}; | |||||
| return std::make_shared<AbstractTuple>(abstract_list); | return std::make_shared<AbstractTuple>(abstract_list); | ||||
| } | } | ||||
| @@ -455,16 +455,16 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager | |||||
| } else if (IsValueNode<ValueList>(node)) { | } else if (IsValueNode<ValueList>(node)) { | ||||
| new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>()); | new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>()); | ||||
| } else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) || | } else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) || | ||||
| IsPrimitiveCNode(node, prim::kPrimMakeIndexedSlices)) { | |||||
| IsPrimitiveCNode(node, prim::kPrimMakeRowTensor)) { | |||||
| new_node = ConvertMakeSparseToMakeTuple(cnode); | new_node = ConvertMakeSparseToMakeTuple(cnode); | ||||
| } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) || | } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) || | ||||
| IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetIndices)) { | |||||
| IsPrimitiveCNode(node, prim::kPrimRowTensorGetIndices)) { | |||||
| new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 0); | new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 0); | ||||
| } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) || | } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) || | ||||
| IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetValues)) { | |||||
| IsPrimitiveCNode(node, prim::kPrimRowTensorGetValues)) { | |||||
| new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 1); | new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 1); | ||||
| } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) || | } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) || | ||||
| IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetDenseShape)) { | |||||
| IsPrimitiveCNode(node, prim::kPrimRowTensorGetDenseShape)) { | |||||
| new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 2); | new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 2); | ||||
| } | } | ||||
| @@ -43,7 +43,7 @@ | |||||
| #include "frontend/optimizer/irpass/transpose_eliminate.h" | #include "frontend/optimizer/irpass/transpose_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/value_based_eliminate.h" | #include "frontend/optimizer/irpass/value_based_eliminate.h" | ||||
| #include "frontend/optimizer/opt.h" | #include "frontend/optimizer/opt.h" | ||||
| #include "frontend/optimizer/irpass/indexed_slices_eliminate.h" | |||||
| #include "frontend/optimizer/irpass/row_tensor_eliminate.h" | |||||
| #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" | #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -157,10 +157,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| mark_interface_fusion_ = | mark_interface_fusion_ = | ||||
| MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect); | MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect); | ||||
| // IndexedSlices Eliminate | |||||
| indexed_slices_eliminate_ = MakeSubstitution( | |||||
| std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate", | |||||
| {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape}); | |||||
| // RowTensor Eliminate | |||||
| row_tensor_eliminate_ = MakeSubstitution( | |||||
| std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate", | |||||
| {prim::kPrimRowTensorGetIndices, prim::kPrimRowTensorGetValues, prim::kPrimRowTensorGetDenseShape}); | |||||
| // SparseTensor Eliminate | // SparseTensor Eliminate | ||||
| sparse_tensor_eliminate_ = MakeSubstitution( | sparse_tensor_eliminate_ = MakeSubstitution( | ||||
| @@ -105,8 +105,8 @@ class OptimizeIRPassLib { | |||||
| // Fusion | // Fusion | ||||
| SubstitutionPtr mark_interface_fusion_; | SubstitutionPtr mark_interface_fusion_; | ||||
| // IndexedSlices Eliminate | |||||
| SubstitutionPtr indexed_slices_eliminate_; | |||||
| // RowTensor Eliminate | |||||
| SubstitutionPtr row_tensor_eliminate_; | |||||
| // SparseTensor Eliminate | // SparseTensor Eliminate | ||||
| SubstitutionPtr sparse_tensor_eliminate_; | SubstitutionPtr sparse_tensor_eliminate_; | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| @@ -28,24 +28,24 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| // {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}} | |||||
| // {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}} | |||||
| // {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}} | |||||
| class IndexedSlicesEliminater : public AnfVisitor { | |||||
| // {prim::kPrimRowTensorGetIndices, {prim::kPrimMakeRowTensor, Xs}} | |||||
| // {prim::kPrimRowTensorGetValues, {prim::kPrimMakeRowTensor, Xs}} | |||||
| // {prim::kPrimRowTensorGetDenseShape, {prim::kPrimMakeRowTensor, Xs}} | |||||
| class RowTensorEliminater : public AnfVisitor { | |||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| Reset(); | Reset(); | ||||
| AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node); | |||||
| AnfVisitor::Match(prim::kPrimRowTensorGetIndices, {IsCNode})(node); | |||||
| if (is_match_) { | if (is_match_) { | ||||
| return tuple_->input(1); | return tuple_->input(1); | ||||
| } | } | ||||
| AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node); | |||||
| AnfVisitor::Match(prim::kPrimRowTensorGetValues, {IsCNode})(node); | |||||
| if (is_match_) { | if (is_match_) { | ||||
| return tuple_->input(2); | return tuple_->input(2); | ||||
| } | } | ||||
| AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node); | |||||
| AnfVisitor::Match(prim::kPrimRowTensorGetDenseShape, {IsCNode})(node); | |||||
| if (is_match_) { | if (is_match_) { | ||||
| return tuple_->input(3); | return tuple_->input(3); | ||||
| @@ -54,7 +54,7 @@ class IndexedSlicesEliminater : public AnfVisitor { | |||||
| } | } | ||||
| void Visit(const CNodePtr &cnode) override { | void Visit(const CNodePtr &cnode) override { | ||||
| if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) { | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimMakeRowTensor)) { | |||||
| tuple_ = cnode; | tuple_ = cnode; | ||||
| is_match_ = true; | is_match_ = true; | ||||
| } | } | ||||
| @@ -72,4 +72,4 @@ class IndexedSlicesEliminater : public AnfVisitor { | |||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_ | |||||
| @@ -170,7 +170,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| irpass.replace_refkey_by_param_, | irpass.replace_refkey_by_param_, | ||||
| irpass.make_ref_eliminate_, | irpass.make_ref_eliminate_, | ||||
| irpass.get_ref_param_eliminate_, | irpass.get_ref_param_eliminate_, | ||||
| irpass.indexed_slices_eliminate_, | |||||
| irpass.row_tensor_eliminate_, | |||||
| }); | }); | ||||
| OptPassGroupMap map({ | OptPassGroupMap map({ | ||||
| {"b_1", b_1}, | {"b_1", b_1}, | ||||
| @@ -30,153 +30,165 @@ namespace mindspore { | |||||
| namespace pipeline { | namespace pipeline { | ||||
| BuiltInTypeMap &GetMethodMap() { | BuiltInTypeMap &GetMethodMap() { | ||||
| static BuiltInTypeMap method_map = { | |||||
| {kObjectTypeString, | |||||
| { | |||||
| {"__bool__", std::string("str_bool")} // C.str_bool | |||||
| }}, | |||||
| {kMetaTypeNone, | |||||
| { | |||||
| {"__bool__", std::string("none_bool")} // C.none_bool | |||||
| }}, | |||||
| {kNumberTypeBool, | |||||
| { | |||||
| {"__and__", prim::kPrimBoolAnd}, // P.bool_and | |||||
| {"__or__", prim::kPrimBoolOr}, // P.bool_or | |||||
| {"__eq__", prim::kPrimBoolEq}, // P.bool_eq | |||||
| {"__ne__", std::string("bool_ne")}, // C.bool_ne | |||||
| {"__bool__", prim::kPrimIdentity} // P.identity | |||||
| }}, | |||||
| {kNumberTypeInt, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul | |||||
| {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv | |||||
| {"__truediv__", std::string("int_truediv")}, // C.int_truediv | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow | |||||
| {"__floor__", prim::kPrimIdentity}, // P.identity | |||||
| {"__trunc__", prim::kPrimIdentity}, // P.identity | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge | |||||
| {"__bool__", std::string("int_bool")}, // C.int_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array | |||||
| }}, | |||||
| {kNumberTypeUInt, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, | |||||
| {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div, | |||||
| {"__truediv__", std::string("int_truediv")}, // C.int_truediv | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, | |||||
| {"__floor__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__trunc__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le, | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, | |||||
| {"__bool__", std::string("int_bool")}, // C.int_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, | |||||
| }}, | |||||
| {kNumberTypeFloat, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, | |||||
| {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv | |||||
| {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div, | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, | |||||
| {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor, | |||||
| {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc, | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le, | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, | |||||
| {"__bool__", std::string("float_bool")}, // C.float_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, | |||||
| }}, | |||||
| {kObjectTypeTuple, | |||||
| { | |||||
| {"__len__", prim::kPrimTupleLen}, // P.tuple_len, | |||||
| {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem, | |||||
| {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem, | |||||
| {"__ms_iter__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__ms_next__", std::string("tuple_next")}, // C.tuple_next, | |||||
| {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext | |||||
| {"__bool__", std::string("tuple_bool")} // C.tuple_bool | |||||
| }}, | |||||
| {kObjectTypeList, | |||||
| { | |||||
| {"__len__", prim::kPrimListLen}, // P.list_len, | |||||
| {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem, | |||||
| {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem, | |||||
| {"__ms_iter__", prim::kPrimIdentity}, // P.identity | |||||
| {"__ms_next__", std::string("list_next")}, // C.list_next | |||||
| {"append", std::string("list_append")}, // C.list_next | |||||
| {"__bool__", std::string("list_bool")}, // C.list_bool | |||||
| {"__ms_hasnext__", std::string("list_hasnext")}, | |||||
| }}, | |||||
| {kObjectTypeDictionary, | |||||
| { | |||||
| {"__len__", prim::kPrimDictLen}, // P.dict_len | |||||
| {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem | |||||
| {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, | |||||
| {"__bool__", std::string("dict_bool")} // C.dict_bool | |||||
| }}, | |||||
| static BuiltInTypeMap method_map = {{kObjectTypeString, | |||||
| { | |||||
| {"__bool__", std::string("str_bool")} // C.str_bool | |||||
| }}, | |||||
| {kMetaTypeNone, | |||||
| { | |||||
| {"__bool__", std::string("none_bool")} // C.none_bool | |||||
| }}, | |||||
| {kNumberTypeBool, | |||||
| { | |||||
| {"__and__", prim::kPrimBoolAnd}, // P.bool_and | |||||
| {"__or__", prim::kPrimBoolOr}, // P.bool_or | |||||
| {"__eq__", prim::kPrimBoolEq}, // P.bool_eq | |||||
| {"__ne__", std::string("bool_ne")}, // C.bool_ne | |||||
| {"__bool__", prim::kPrimIdentity} // P.identity | |||||
| }}, | |||||
| {kNumberTypeInt, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul | |||||
| {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv | |||||
| {"__truediv__", std::string("int_truediv")}, // C.int_truediv | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow | |||||
| {"__floor__", prim::kPrimIdentity}, // P.identity | |||||
| {"__trunc__", prim::kPrimIdentity}, // P.identity | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge | |||||
| {"__bool__", std::string("int_bool")}, // C.int_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array | |||||
| }}, | |||||
| {kNumberTypeUInt, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, | |||||
| {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div, | |||||
| {"__truediv__", std::string("int_truediv")}, // C.int_truediv | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, | |||||
| {"__floor__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__trunc__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le, | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, | |||||
| {"__bool__", std::string("int_bool")}, // C.int_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, | |||||
| }}, | |||||
| {kNumberTypeFloat, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, | |||||
| {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv | |||||
| {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div, | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, | |||||
| {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor, | |||||
| {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc, | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le, | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, | |||||
| {"__bool__", std::string("float_bool")}, // C.float_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, | |||||
| }}, | |||||
| {kObjectTypeTuple, | |||||
| { | |||||
| {"__len__", prim::kPrimTupleLen}, // P.tuple_len, | |||||
| {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem, | |||||
| {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem, | |||||
| {"__ms_iter__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__ms_next__", std::string("tuple_next")}, // C.tuple_next, | |||||
| {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext | |||||
| {"__bool__", std::string("tuple_bool")} // C.tuple_bool | |||||
| }}, | |||||
| {kObjectTypeList, | |||||
| { | |||||
| {"__len__", prim::kPrimListLen}, // P.list_len, | |||||
| {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem, | |||||
| {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem, | |||||
| {"__ms_iter__", prim::kPrimIdentity}, // P.identity | |||||
| {"__ms_next__", std::string("list_next")}, // C.list_next | |||||
| {"append", std::string("list_append")}, // C.list_next | |||||
| {"__bool__", std::string("list_bool")}, // C.list_bool | |||||
| {"__ms_hasnext__", std::string("list_hasnext")}, | |||||
| }}, | |||||
| {kObjectTypeDictionary, | |||||
| { | |||||
| {"__len__", prim::kPrimDictLen}, // P.dict_len | |||||
| {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem | |||||
| {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, | |||||
| {"__bool__", std::string("dict_bool")} // C.dict_bool | |||||
| }}, | |||||
| {kObjectTypeTensorType, | |||||
| { | |||||
| {"all", std::string("all_")}, // C.reduce_all | |||||
| {"any", std::string("any_")}, // C.reduce_any | |||||
| {"__add__", std::string("add")}, // C.add | |||||
| {"__sub__", std::string("sub")}, // C.sub | |||||
| {"__mul__", std::string("mul")}, // C.mul | |||||
| {"__truediv__", std::string("truediv")}, // C.truediv | |||||
| {"__floordiv__", std::string("floordiv")}, // C.floordiv | |||||
| {"__mod__", std::string("mod")}, // C.mod | |||||
| {"__pow__", std::string("pow_")}, // C.pow | |||||
| {"__floor__", std::string("array_floor")}, // C.array_floor | |||||
| {"__trunc__", std::string("array_trunc")}, // C.array_trunc | |||||
| {"__pos__", std::string("array_uadd")}, // C.array_uadd | |||||
| {"__neg__", std::string("array_usub")}, // C.array_usub | |||||
| {"__eq__", std::string("eq")}, // C.eq | |||||
| {"__ne__", std::string("ne")}, // C.ne | |||||
| {"__lt__", std::string("lt")}, // C.lt | |||||
| {"__gt__", std::string("gt")}, // C.gt | |||||
| {"__le__", std::string("le")}, // C.le | |||||
| {"__ge__", std::string("ge")}, // C.ge | |||||
| {"__matmul__", prim::kPrimDot}, // P.dot, | |||||
| {"__len__", prim::kPrimArrayLen}, // P.array_len, | |||||
| {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, | |||||
| {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, | |||||
| {"__ms_iter__", std::string("array_iter")}, // C.array_iter | |||||
| {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, | |||||
| {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, | |||||
| {"transpose", std::string("transpose")}, // P.transpose | |||||
| {"__bool__", std::string("tensor_bool")}, // C.tensor_bool | |||||
| }}, | |||||
| {kObjectTypeJTagged, {}}, | |||||
| {kObjectTypeSymbolicKeyType, {}}, | |||||
| {kObjectTypeEnvType, {}}}; | |||||
| return method_map; | |||||
| } | |||||
| BuiltInTypeMap &GetAttrMap() { | |||||
| static BuiltInTypeMap attr_map = { | |||||
| {kObjectTypeTensorType, | {kObjectTypeTensorType, | ||||
| { | { | ||||
| {"all", std::string("all_")}, // C.reduce_all | |||||
| {"any", std::string("any_")}, // C.reduce_any | |||||
| {"__add__", std::string("add")}, // C.add | |||||
| {"__sub__", std::string("sub")}, // C.sub | |||||
| {"__mul__", std::string("mul")}, // C.mul | |||||
| {"__truediv__", std::string("truediv")}, // C.truediv | |||||
| {"__floordiv__", std::string("floordiv")}, // C.floordiv | |||||
| {"__mod__", std::string("mod")}, // C.mod | |||||
| {"__pow__", std::string("pow_")}, // C.pow | |||||
| {"__floor__", std::string("array_floor")}, // C.array_floor | |||||
| {"__trunc__", std::string("array_trunc")}, // C.array_trunc | |||||
| {"__pos__", std::string("array_uadd")}, // C.array_uadd | |||||
| {"__neg__", std::string("array_usub")}, // C.array_usub | |||||
| {"__eq__", std::string("eq")}, // C.eq | |||||
| {"__ne__", std::string("ne")}, // C.ne | |||||
| {"__lt__", std::string("lt")}, // C.lt | |||||
| {"__gt__", std::string("gt")}, // C.gt | |||||
| {"__le__", std::string("le")}, // C.le | |||||
| {"__ge__", std::string("ge")}, // C.ge | |||||
| {"__matmul__", prim::kPrimDot}, // P.dot, | |||||
| {"__len__", prim::kPrimArrayLen}, // P.array_len, | |||||
| {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, | |||||
| {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, | |||||
| {"__ms_iter__", std::string("array_iter")}, // C.array_iter | |||||
| {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, | |||||
| {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, | |||||
| {"transpose", std::string("transpose")}, // P.transpose | |||||
| {"__bool__", std::string("tensor_bool")}, // C.tensor_bool | |||||
| {"shape", std::string("shape_")}, // C.shape_ | |||||
| {"dtype", std::string("dtype_")}, // C.dtype_ | |||||
| }}, | }}, | ||||
| {kObjectTypeIndexedSlicesType, | |||||
| {kObjectTypeRowTensorType, | |||||
| { | { | ||||
| {"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values | |||||
| {"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices | |||||
| {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape | |||||
| {"values", prim::kPrimRowTensorGetValues}, // F.row_tensor_get_values | |||||
| {"indices", prim::kPrimRowTensorGetIndices}, // F.row_tensor_get_indices | |||||
| {"dense_shape", prim::kPrimRowTensorGetDenseShape}, // F.row_tensor_get_dense_shape | |||||
| }}, | }}, | ||||
| {kObjectTypeSparseTensorType, | {kObjectTypeSparseTensorType, | ||||
| { | { | ||||
| @@ -184,18 +196,7 @@ BuiltInTypeMap &GetMethodMap() { | |||||
| {"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices | {"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices | ||||
| {"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape | {"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape | ||||
| }}, | }}, | ||||
| {kObjectTypeJTagged, {}}, | |||||
| {kObjectTypeSymbolicKeyType, {}}, | |||||
| {kObjectTypeEnvType, {}}}; | |||||
| return method_map; | |||||
| } | |||||
| BuiltInTypeMap &GetAttrMap() { | |||||
| static BuiltInTypeMap attr_map = {{kObjectTypeTensorType, | |||||
| { | |||||
| {"shape", std::string("shape_")}, // C.shape_ | |||||
| {"dtype", std::string("dtype_")}, // C.dtype_ | |||||
| }}}; | |||||
| }; | |||||
| return attr_map; | return attr_map; | ||||
| } | } | ||||
| @@ -132,11 +132,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimControlDepend, {InferImplControlDepend, true}}, | {prim::kPrimControlDepend, {InferImplControlDepend, true}}, | ||||
| // Debug | // Debug | ||||
| {prim::kPrimDebug, {InferImplDebug, true}}, | {prim::kPrimDebug, {InferImplDebug, true}}, | ||||
| // IndexedSlices | |||||
| {prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}}, | |||||
| {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}}, | |||||
| {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}}, | |||||
| {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}}, | |||||
| // RowTensor | |||||
| {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}}, | |||||
| {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}}, | |||||
| {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}}, | |||||
| {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, | |||||
| // SparseTensor | // SparseTensor | ||||
| {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}}, | {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}}, | ||||
| {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}}, | {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}}, | ||||
| @@ -402,8 +402,8 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { | |||||
| } | } | ||||
| dic["dtype"] = arg_tensor->BuildType(); | dic["dtype"] = arg_tensor->BuildType(); | ||||
| dic["value"] = BuildValue(arg_tensor->BuildValue()); | dic["value"] = BuildValue(arg_tensor->BuildValue()); | ||||
| } else if (abs_base->isa<AbstractIndexedSlices>()) { | |||||
| auto arg = dyn_cast<AbstractIndexedSlices>(abs_base); | |||||
| } else if (abs_base->isa<AbstractRowTensor>()) { | |||||
| auto arg = dyn_cast<AbstractRowTensor>(abs_base); | |||||
| dic["shape"] = arg->shape()->shape(); | dic["shape"] = arg->shape()->shape(); | ||||
| dic["dtype"] = arg->BuildType(); | dic["dtype"] = arg->BuildType(); | ||||
| dic["value"] = BuildValue(arg->BuildValue()); | dic["value"] = BuildValue(arg->BuildValue()); | ||||
| @@ -348,14 +348,14 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv | |||||
| AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -32,9 +32,9 @@ using mindspore::abstract::AbstractBase; | |||||
| using mindspore::abstract::AbstractClass; | using mindspore::abstract::AbstractClass; | ||||
| using mindspore::abstract::AbstractError; | using mindspore::abstract::AbstractError; | ||||
| using mindspore::abstract::AbstractFunction; | using mindspore::abstract::AbstractFunction; | ||||
| using mindspore::abstract::AbstractIndexedSlices; | |||||
| using mindspore::abstract::AbstractJTagged; | using mindspore::abstract::AbstractJTagged; | ||||
| using mindspore::abstract::AbstractList; | using mindspore::abstract::AbstractList; | ||||
| using mindspore::abstract::AbstractRowTensor; | |||||
| using mindspore::abstract::AbstractScalar; | using mindspore::abstract::AbstractScalar; | ||||
| using mindspore::abstract::AbstractSparseTensor; | using mindspore::abstract::AbstractSparseTensor; | ||||
| using mindspore::abstract::AbstractTensor; | using mindspore::abstract::AbstractTensor; | ||||
| @@ -95,7 +95,7 @@ void ValidateAbstract(const AnfNodePtr &node) { | |||||
| } | } | ||||
| if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() || | if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() || | ||||
| ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() || | |||||
| ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() || | |||||
| ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) { | ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -136,8 +136,7 @@ REGISTER_PYBIND_DEFINE( | |||||
| TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>())))); | TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>())))); | ||||
| return data; | return data; | ||||
| })); | })); | ||||
| (void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType") | |||||
| .def(py::init()); | |||||
| (void)py::class_<RowTensorType, Type, std::shared_ptr<RowTensorType>>(m_sub, "RowTensorType").def(py::init()); | |||||
| (void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType") | (void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType") | ||||
| .def(py::init()); | .def(py::init()); | ||||
| (void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType") | (void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType") | ||||
| @@ -17,10 +17,10 @@ from . import dtype | |||||
| from .api import ms_function | from .api import ms_function | ||||
| from .dtype import * | from .dtype import * | ||||
| from .parameter import Parameter, ParameterTuple | from .parameter import Parameter, ParameterTuple | ||||
| from .tensor import MetaTensor, Tensor, IndexedSlices, SparseTensor | |||||
| from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor | |||||
| __all__ = [ | __all__ = [ | ||||
| "MetaTensor", "Tensor", "IndexedSlices", "SparseTensor", # tensor | |||||
| "MetaTensor", "Tensor", "RowTensor", "SparseTensor", # tensor | |||||
| 'ms_function', # api | 'ms_function', # api | ||||
| 'Parameter', 'ParameterTuple', # parameter | 'Parameter', 'ParameterTuple', # parameter | ||||
| "dtype" | "dtype" | ||||
| @@ -99,7 +99,7 @@ slice_type = typing.Slice | |||||
| ellipsis_type = typing.TypeEllipsis | ellipsis_type = typing.TypeEllipsis | ||||
| list_type = typing.List | list_type = typing.List | ||||
| tuple_type = typing.Tuple | tuple_type = typing.Tuple | ||||
| index_slices = typing.IndexedSlicesType() | |||||
| index_slices = typing.RowTensorType() | |||||
| sparse_tensor = typing.SparseTensorType() | sparse_tensor = typing.SparseTensorType() | ||||
| undetermined = typing.UndeterminedType() | undetermined = typing.UndeterminedType() | ||||
| @@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename | |||||
| from . import dtype as mstype | from . import dtype as mstype | ||||
| from ._register_for_tensor import tensor_operator_registry | from ._register_for_tensor import tensor_operator_registry | ||||
| __all__ = ['Tensor', 'MetaTensor', 'IndexedSlices', 'SparseTensor'] | |||||
| __all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor'] | |||||
| np_types = (np.int8, np.int16, np.int32, np.int64, | np_types = (np.int8, np.int16, np.int32, np.int64, | ||||
| np.uint8, np.uint16, np.uint32, np.uint64, np.float16, | np.uint8, np.uint16, np.uint32, np.uint64, np.float16, | ||||
| np.float32, np.float64, np.bool_) | np.float32, np.float64, np.bool_) | ||||
| @@ -267,20 +267,20 @@ class Tensor(Tensor_): | |||||
| return tensor_operator_registry.get('any')(keep_dims)(self, axis) | return tensor_operator_registry.get('any')(keep_dims)(self, axis) | ||||
| class IndexedSlices: | |||||
| class RowTensor: | |||||
| """ | """ | ||||
| A sparse representation of a set of tensor slices at given indices. | A sparse representation of a set of tensor slices at given indices. | ||||
| An IndexedSlices is typically used to represent a subset of a larger | |||||
| An RowTensor is typically used to represent a subset of a larger | |||||
| tensor dense of shape [L0, D1, .. , DN] where L0 >> D0. | tensor dense of shape [L0, D1, .. , DN] where L0 >> D0. | ||||
| The values in indices are the indices in the first dimension of the slices | The values in indices are the indices in the first dimension of the slices | ||||
| that have been extracted from the larger tensor. | that have been extracted from the larger tensor. | ||||
| The dense tensor dense represented by an IndexedSlices slices has | |||||
| The dense tensor dense represented by an RowTensor slices has | |||||
| `dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`. | `dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`. | ||||
| IndexedSlices can only be used in the `Cell`'s construct method. | |||||
| RowTensor can only be used in the `Cell`'s contruct method. | |||||
| It is not supported in pynative mode at the moment. | It is not supported in pynative mode at the moment. | ||||
| @@ -291,7 +291,7 @@ class IndexedSlices: | |||||
| of the corresponding dense tensor. | of the corresponding dense tensor. | ||||
| Returns: | Returns: | ||||
| IndexedSlices, composed of `indices`, `values`, and `dense_shape`. | |||||
| RowTensor, composed of `indices`, `values`, and `dense_shape`. | |||||
| Examples: | Examples: | ||||
| >>> class Net(nn.Cell): | >>> class Net(nn.Cell): | ||||
| @@ -299,8 +299,8 @@ class IndexedSlices: | |||||
| >>> super(Net, self).__init__() | >>> super(Net, self).__init__() | ||||
| >>> self.dense_shape = dense_shape | >>> self.dense_shape = dense_shape | ||||
| >>> def construct(self, indices, values): | >>> def construct(self, indices, values): | ||||
| >>> x = IndexedSlices(indices, values, self.dense_shape) | |||||
| >>> return x.values(), x.indices(), x.dense_shape() | |||||
| >>> x = RowTensor(indices, values, self.dense_shape) | |||||
| >>> return x.values, x.indices, x.dense_shape | |||||
| >>> | >>> | ||||
| >>> indices = Tensor([0]) | >>> indices = Tensor([0]) | ||||
| >>> values = Tensor([[1, 2]], dtype=ms.float32) | >>> values = Tensor([[1, 2]], dtype=ms.float32) | ||||
| @@ -308,17 +308,20 @@ class IndexedSlices: | |||||
| """ | """ | ||||
| def __init__(self, indices, values, dense_shape): | def __init__(self, indices, values, dense_shape): | ||||
| "Init IndexedSlices" | |||||
| "Init RowTensor" | |||||
| self.__indices = indices | self.__indices = indices | ||||
| self.__values = values | self.__values = values | ||||
| self.__dense_shape = dense_shape | self.__dense_shape = dense_shape | ||||
| @property | |||||
| def indices(self): | def indices(self): | ||||
| return self.__indices | return self.__indices | ||||
| @property | |||||
| def values(self): | def values(self): | ||||
| return self.__values | return self.__values | ||||
| @property | |||||
| def dense_shape(self): | def dense_shape(self): | ||||
| return self.__dense_shape | return self.__dense_shape | ||||
| @@ -353,7 +356,7 @@ class SparseTensor: | |||||
| >>> self.dense_shape = dense_shape | >>> self.dense_shape = dense_shape | ||||
| >>> def construct(self, indices, values): | >>> def construct(self, indices, values): | ||||
| >>> x = SparseTensor(indices, values, self.dense_shape) | >>> x = SparseTensor(indices, values, self.dense_shape) | ||||
| >>> return x.values(), x.indices(), x.dense_shape() | |||||
| >>> return x.values, x.indices, x.dense_shape | |||||
| >>> | >>> | ||||
| >>> indices = Tensor([[0, 1], [1, 2]]) | >>> indices = Tensor([[0, 1], [1, 2]]) | ||||
| >>> values = Tensor([1, 2], dtype=ms.float32) | >>> values = Tensor([1, 2], dtype=ms.float32) | ||||
| @@ -366,11 +369,14 @@ class SparseTensor: | |||||
| self.__values = values | self.__values = values | ||||
| self.__dense_shape = dense_shape | self.__dense_shape = dense_shape | ||||
| @property | |||||
| def indices(self): | def indices(self): | ||||
| return self.__indices | return self.__indices | ||||
| @property | |||||
| def values(self): | def values(self): | ||||
| return self.__values | return self.__values | ||||
| @property | |||||
| def dense_shape(self): | def dense_shape(self): | ||||
| return self.__dense_shape | return self.__dense_shape | ||||
| @@ -1050,16 +1050,16 @@ bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const | |||||
| return AbstractBasePtrListDeepEqual(lhs, rhs); | return AbstractBasePtrListDeepEqual(lhs, rhs); | ||||
| } | } | ||||
| // IndexedSlices | |||||
| TypePtr AbstractIndexedSlices::BuildType() const { | |||||
| // RowTensor | |||||
| TypePtr AbstractRowTensor::BuildType() const { | |||||
| MS_EXCEPTION_IF_NULL(element()); | MS_EXCEPTION_IF_NULL(element()); | ||||
| TypePtr element_type = element()->BuildType(); | TypePtr element_type = element()->BuildType(); | ||||
| return std::make_shared<IndexedSlicesType>(element_type); | |||||
| return std::make_shared<RowTensorType>(element_type); | |||||
| } | } | ||||
| AbstractBasePtr AbstractIndexedSlices::Clone() const { | |||||
| AbstractBasePtr AbstractRowTensor::Clone() const { | |||||
| MS_EXCEPTION_IF_NULL(element()); | MS_EXCEPTION_IF_NULL(element()); | ||||
| auto clone = std::make_shared<AbstractIndexedSlices>(element()->Clone()); | |||||
| auto clone = std::make_shared<AbstractRowTensor>(element()->Clone()); | |||||
| ShapePtr shp = shape(); | ShapePtr shp = shape(); | ||||
| clone->set_shape(shp->Clone()); | clone->set_shape(shp->Clone()); | ||||
| clone->set_value(GetValueTrack()); | clone->set_value(GetValueTrack()); | ||||
| @@ -1069,9 +1069,9 @@ AbstractBasePtr AbstractIndexedSlices::Clone() const { | |||||
| return clone; | return clone; | ||||
| } | } | ||||
| AbstractBasePtr AbstractIndexedSlices::Broaden() const { | |||||
| AbstractBasePtr AbstractRowTensor::Broaden() const { | |||||
| MS_EXCEPTION_IF_NULL(element()); | MS_EXCEPTION_IF_NULL(element()); | ||||
| auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden()); | |||||
| auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden()); | |||||
| auto shp = shape(); | auto shp = shape(); | ||||
| broaden->set_shape(shp->Clone()); | broaden->set_shape(shp->Clone()); | ||||
| broaden->set_value(kAnyValue); | broaden->set_value(kAnyValue); | ||||
| @@ -1081,9 +1081,9 @@ AbstractBasePtr AbstractIndexedSlices::Broaden() const { | |||||
| return broaden; | return broaden; | ||||
| } | } | ||||
| AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const { | |||||
| AbstractBasePtr AbstractRowTensor::BroadenWithShape() const { | |||||
| MS_EXCEPTION_IF_NULL(element()); | MS_EXCEPTION_IF_NULL(element()); | ||||
| auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden()); | |||||
| auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden()); | |||||
| auto shp = shape()->Clone(); | auto shp = shape()->Clone(); | ||||
| shp->Broaden(); | shp->Broaden(); | ||||
| broaden->set_shape(shp); | broaden->set_shape(shp); | ||||
| @@ -1094,7 +1094,7 @@ AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const { | |||||
| return broaden; | return broaden; | ||||
| } | } | ||||
| std::string AbstractIndexedSlices::ToString() const { | |||||
| std::string AbstractRowTensor::ToString() const { | |||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| BaseShapePtr shape_track = GetShapeTrack(); | BaseShapePtr shape_track = GetShapeTrack(); | ||||
| MS_EXCEPTION_IF_NULL(shape_track); | MS_EXCEPTION_IF_NULL(shape_track); | ||||
| @@ -593,15 +593,15 @@ struct AbstractBasePtrListEqual { | |||||
| std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list); | std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list); | ||||
| bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); | bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); | ||||
| // IndexedSlices | |||||
| class AbstractIndexedSlices : public AbstractUndetermined { | |||||
| // RowTensor | |||||
| class AbstractRowTensor : public AbstractUndetermined { | |||||
| public: | public: | ||||
| explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()) | |||||
| explicit AbstractRowTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()) | |||||
| : AbstractUndetermined(element, shape) {} | : AbstractUndetermined(element, shape) {} | ||||
| AbstractIndexedSlices(const TypePtr &element_type, const std::vector<int> &shape) | |||||
| AbstractRowTensor(const TypePtr &element_type, const std::vector<int> &shape) | |||||
| : AbstractUndetermined(element_type, shape) {} | : AbstractUndetermined(element_type, shape) {} | ||||
| ~AbstractIndexedSlices() override = default; | |||||
| MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined) | |||||
| ~AbstractRowTensor() override = default; | |||||
| MS_DECLARE_PARENT(AbstractRowTensor, AbstractUndetermined) | |||||
| const AbstractTensorPtr indices() const { return indices_; } | const AbstractTensorPtr indices() const { return indices_; } | ||||
| void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; } | void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; } | ||||
| @@ -66,7 +66,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Function) | |||||
| ABSTRACT_REPORT_NAME_TRAITS(Type) | ABSTRACT_REPORT_NAME_TRAITS(Type) | ||||
| ABSTRACT_REPORT_NAME_TRAITS(KeywordArg) | ABSTRACT_REPORT_NAME_TRAITS(KeywordArg) | ||||
| ABSTRACT_REPORT_NAME_TRAITS(Class) | ABSTRACT_REPORT_NAME_TRAITS(Class) | ||||
| ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices) | |||||
| ABSTRACT_REPORT_NAME_TRAITS(RowTensor) | |||||
| ABSTRACT_REPORT_NAME_TRAITS(SparseTensor) | ABSTRACT_REPORT_NAME_TRAITS(SparseTensor) | ||||
| ABSTRACT_REPORT_NAME_TRAITS(Sequeue) | ABSTRACT_REPORT_NAME_TRAITS(Sequeue) | ||||
| @@ -179,40 +179,40 @@ bool TensorType::operator==(const Type &other) const { | |||||
| return *element_type_ == *other_elem_type; | return *element_type_ == *other_elem_type; | ||||
| } | } | ||||
| TypePtr IndexedSlicesType::DeepCopy() const { | |||||
| TypePtr RowTensorType::DeepCopy() const { | |||||
| MS_EXCEPTION_IF_NULL(element_type_); | MS_EXCEPTION_IF_NULL(element_type_); | ||||
| if (IsGeneric()) { | if (IsGeneric()) { | ||||
| return std::make_shared<IndexedSlicesType>(); | |||||
| return std::make_shared<RowTensorType>(); | |||||
| } | } | ||||
| return std::make_shared<IndexedSlicesType>(element_type_->DeepCopy()); | |||||
| return std::make_shared<RowTensorType>(element_type_->DeepCopy()); | |||||
| } | } | ||||
| std::string IndexedSlicesType::ToReprString() const { | |||||
| std::string RowTensorType::ToReprString() const { | |||||
| if (element_type_ == nullptr) { | if (element_type_ == nullptr) { | ||||
| return "IndexedSlices"; | |||||
| return "RowTensor"; | |||||
| } | } | ||||
| return "IndexedSlices[" + element_type_->ToReprString() + "]"; | |||||
| return "RowTensor[" + element_type_->ToReprString() + "]"; | |||||
| } | } | ||||
| std::string IndexedSlicesType::ToString() const { | |||||
| std::string RowTensorType::ToString() const { | |||||
| if (element_type_ == nullptr) { | if (element_type_ == nullptr) { | ||||
| return "IndexedSlices"; | |||||
| return "RowTensor"; | |||||
| } | } | ||||
| return "IndexedSlices[" + element_type_->ToString() + "]"; | |||||
| return "RowTensor[" + element_type_->ToString() + "]"; | |||||
| } | } | ||||
| std::string IndexedSlicesType::DumpText() const { | |||||
| std::string RowTensorType::DumpText() const { | |||||
| if (element_type_ == nullptr) { | if (element_type_ == nullptr) { | ||||
| return "IndexedSlices"; | |||||
| return "RowTensor"; | |||||
| } | } | ||||
| return "IndexedSlices[" + element_type_->DumpText() + "]"; | |||||
| return "RowTensor[" + element_type_->DumpText() + "]"; | |||||
| } | } | ||||
| bool IndexedSlicesType::operator==(const Type &other) const { | |||||
| bool RowTensorType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | if (!IsSameObjectType(*this, other)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto other_elem_type = static_cast<const IndexedSlicesType &>(other).element_type_; | |||||
| auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_; | |||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | if (element_type_ == nullptr && other_elem_type == nullptr) { | ||||
| return true; | return true; | ||||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | } else if (element_type_ == nullptr || other_elem_type == nullptr) { | ||||
| @@ -154,15 +154,15 @@ class TensorType : public Object { | |||||
| }; | }; | ||||
| using TensorTypePtr = std::shared_ptr<TensorType>; | using TensorTypePtr = std::shared_ptr<TensorType>; | ||||
| class IndexedSlicesType : public Object { | |||||
| class RowTensorType : public Object { | |||||
| public: | public: | ||||
| IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {} | |||||
| explicit IndexedSlicesType(const TypePtr &ele) | |||||
| : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||||
| ~IndexedSlicesType() override = default; | |||||
| MS_DECLARE_PARENT(IndexedSlicesType, Object) | |||||
| RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {} | |||||
| explicit RowTensorType(const TypePtr &ele) | |||||
| : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||||
| ~RowTensorType() override = default; | |||||
| MS_DECLARE_PARENT(RowTensorType, Object) | |||||
| TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; } | |||||
| TypeId generic_type_id() const override { return kObjectTypeRowTensorType; } | |||||
| const TypePtr element() const { return element_type_; } | const TypePtr element() const { return element_type_; } | ||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | void set_element(const TypePtr &element_type) { element_type_ = element_type; } | ||||
| @@ -175,7 +175,7 @@ class IndexedSlicesType : public Object { | |||||
| private: | private: | ||||
| TypePtr element_type_; | TypePtr element_type_; | ||||
| }; | }; | ||||
| using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>; | |||||
| using RowTensorTypePtr = std::shared_ptr<RowTensorType>; | |||||
| class SparseTensorType : public Object { | class SparseTensorType : public Object { | ||||
| public: | public: | ||||
| @@ -115,8 +115,8 @@ const char *ObjectIdLabel(const TypeId &v) { | |||||
| return "kObjectTypeKeyword"; | return "kObjectTypeKeyword"; | ||||
| case kObjectTypeTensorType: | case kObjectTypeTensorType: | ||||
| return "kObjectTypeTensorType"; | return "kObjectTypeTensorType"; | ||||
| case kObjectTypeIndexedSlicesType: | |||||
| return "kObjectTypeIndexedSlicesType"; | |||||
| case kObjectTypeRowTensorType: | |||||
| return "kObjectTypeRowTensorType"; | |||||
| case kObjectTypeSparseTensorType: | case kObjectTypeSparseTensorType: | ||||
| return "kObjectTypeSparseTensorType"; | return "kObjectTypeSparseTensorType"; | ||||
| case kObjectTypeUndeterminedType: | case kObjectTypeUndeterminedType: | ||||
| @@ -50,7 +50,7 @@ enum TypeId : int { | |||||
| kObjectTypeSlice, | kObjectTypeSlice, | ||||
| kObjectTypeKeyword, | kObjectTypeKeyword, | ||||
| kObjectTypeTensorType, | kObjectTypeTensorType, | ||||
| kObjectTypeIndexedSlicesType, | |||||
| kObjectTypeRowTensorType, | |||||
| kObjectTypeSparseTensorType, | kObjectTypeSparseTensorType, | ||||
| kObjectTypeUndeterminedType, | kObjectTypeUndeterminedType, | ||||
| kObjectTypeClass, | kObjectTypeClass, | ||||
| @@ -190,9 +190,9 @@ TypePtr TensorStrToType(const std::string &type_name) { | |||||
| return type; | return type; | ||||
| } | } | ||||
| TypePtr IndexedSlicesStrToType(const std::string &type_name) { | |||||
| if (type_name == "IndexedSlices") { | |||||
| return std::make_shared<IndexedSlicesType>(); | |||||
| TypePtr RowTensorStrToType(const std::string &type_name) { | |||||
| if (type_name == "RowTensor") { | |||||
| return std::make_shared<RowTensorType>(); | |||||
| } | } | ||||
| auto start = type_name.find_first_of('[') + 1; | auto start = type_name.find_first_of('[') + 1; | ||||
| auto end = type_name.find_last_of(']'); | auto end = type_name.find_last_of(']'); | ||||
| @@ -204,7 +204,7 @@ TypePtr IndexedSlicesStrToType(const std::string &type_name) { | |||||
| if (element_type == nullptr) { | if (element_type == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return std::make_shared<IndexedSlicesType>(element_type); | |||||
| return std::make_shared<RowTensorType>(element_type); | |||||
| } | } | ||||
| TypePtr SparseTensorStrToType(const std::string &type_name) { | TypePtr SparseTensorStrToType(const std::string &type_name) { | ||||
| @@ -364,8 +364,8 @@ TypePtr StringToType(const std::string &type_name) { | |||||
| type = TensorStrToType(type_name); | type = TensorStrToType(type_name); | ||||
| } else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) { | } else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) { | ||||
| type = UndeterminedStrToType(type_name); | type = UndeterminedStrToType(type_name); | ||||
| } else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) { | |||||
| type = IndexedSlicesStrToType(type_name); | |||||
| } else if (type_name.compare(0, strlen("RowTensor"), "RowTensor") == 0) { | |||||
| type = RowTensorStrToType(type_name); | |||||
| } else if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) { | } else if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) { | ||||
| type = SparseTensorStrToType(type_name); | type = SparseTensorStrToType(type_name); | ||||
| } else if (type_name.compare(0, strlen("List"), "List") == 0) { | } else if (type_name.compare(0, strlen("List"), "List") == 0) { | ||||
| @@ -446,7 +446,7 @@ const TypePtr kTypeExternal = std::make_shared<External>(); | |||||
| const TypePtr kTypeEnv = std::make_shared<EnvType>(); | const TypePtr kTypeEnv = std::make_shared<EnvType>(); | ||||
| const TypePtr kTypeType = std::make_shared<TypeType>(); | const TypePtr kTypeType = std::make_shared<TypeType>(); | ||||
| const TypePtr kTensorType = std::make_shared<TensorType>(); | const TypePtr kTensorType = std::make_shared<TensorType>(); | ||||
| const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>(); | |||||
| const TypePtr kRowTensorType = std::make_shared<RowTensorType>(); | |||||
| const TypePtr kSparseTensorType = std::make_shared<SparseTensorType>(); | const TypePtr kSparseTensorType = std::make_shared<SparseTensorType>(); | ||||
| const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>(); | const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>(); | ||||
| const TypePtr kString = std::make_shared<String>(); | const TypePtr kString = std::make_shared<String>(); | ||||
| @@ -85,13 +85,13 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d | |||||
| @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | ||||
| "Tensor", "IndexedSlices", "Tensor", "Tensor", "Tensor", "Bool") | |||||
| "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool") | |||||
| def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, | def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, | ||||
| gradient, params, moment1, moment2, ps_parameter): | gradient, params, moment1, moment2, ps_parameter): | ||||
| """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" | """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" | ||||
| success = True | success = True | ||||
| indices = gradient.indices() | |||||
| values = gradient.values() | |||||
| indices = gradient.indices | |||||
| values = gradient.values | |||||
| if ps_parameter: | if ps_parameter: | ||||
| op_shape = P.Shape() | op_shape = P.Shape() | ||||
| shapes = (op_shape(params), op_shape(moment1), op_shape(moment2), | shapes = (op_shape(params), op_shape(moment1), op_shape(moment2), | ||||
| @@ -24,13 +24,13 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt") | |||||
| @_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", | @_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", | ||||
| "IndexedSlices", "Tensor", "Tensor", "Bool") | |||||
| "RowTensor", "Tensor", "Tensor", "Bool") | |||||
| def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, | def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, | ||||
| gradient, weight, moment, ps_parameter): | gradient, weight, moment, ps_parameter): | ||||
| """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse.""" | """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse.""" | ||||
| success = True | success = True | ||||
| indices = gradient.indices() | |||||
| values = gradient.values() | |||||
| indices = gradient.indices | |||||
| values = gradient.values | |||||
| if ps_parameter: | if ps_parameter: | ||||
| op_shape = P.Shape() | op_shape = P.Shape() | ||||
| shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices)) | shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices)) | ||||
| @@ -28,13 +28,13 @@ _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt") | |||||
| @_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", | @_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", | ||||
| "IndexedSlices", "Tensor", "Tensor", "Tensor") | |||||
| "RowTensor", "Tensor", "Tensor", "Tensor") | |||||
| def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, | def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, | ||||
| moment1, moment2): | moment1, moment2): | ||||
| """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse.""" | """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse.""" | ||||
| success = True | success = True | ||||
| success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, | success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, | ||||
| eps, gradient.values(), gradient.indices())) | |||||
| eps, gradient.values, gradient.indices)) | |||||
| return success | return success | ||||
| @@ -23,7 +23,7 @@ from mindspore.nn.cell import Cell | |||||
| from mindspore.nn.layer.container import CellList | from mindspore.nn.layer.container import CellList | ||||
| from mindspore.common.parameter import Parameter, ParameterTuple | from mindspore.common.parameter import Parameter, ParameterTuple | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.common.tensor import Tensor, IndexedSlices | |||||
| from mindspore.common.tensor import Tensor, RowTensor | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| from mindspore._checkparam import Rel | from mindspore._checkparam import Rel | ||||
| @@ -493,14 +493,14 @@ op_gather = P.GatherV2() | |||||
| _apply_decay = C.MultitypeFuncGraph("apply_decay") | _apply_decay = C.MultitypeFuncGraph("apply_decay") | ||||
| @_apply_decay.register("Number", "Bool", "Tensor", "IndexedSlices") | |||||
| @_apply_decay.register("Number", "Bool", "Tensor", "RowTensor") | |||||
| def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient): | def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient): | ||||
| """Get grad with weight_decay.""" | """Get grad with weight_decay.""" | ||||
| if if_apply: | if if_apply: | ||||
| indices = gradient.indices() | |||||
| values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values())) | |||||
| shape = gradient.dense_shape() | |||||
| return IndexedSlices(indices, values, shape) | |||||
| indices = gradient.indices | |||||
| values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values)) | |||||
| shape = gradient.dense_shape | |||||
| return RowTensor(indices, values, shape) | |||||
| return gradient | return gradient | ||||
| @@ -523,12 +523,12 @@ def tensor_grad_scale(scale, grad): | |||||
| return grad * scale | return grad * scale | ||||
| @_grad_scale.register("Number", "IndexedSlices") | |||||
| @_grad_scale.register("Number", "RowTensor") | |||||
| def tensor_grad_scale_with_sparse(scale, grad): | def tensor_grad_scale_with_sparse(scale, grad): | ||||
| """Get grad with scale.""" | """Get grad with scale.""" | ||||
| if scale == 1.0: | if scale == 1.0: | ||||
| return grad | return grad | ||||
| return IndexedSlices(grad.indices(), grad.values() * scale, grad.dense_shape()) | |||||
| return RowTensor(grad.indices, grad.values * scale, grad.dense_shape) | |||||
| class _ConvertToCell(LearningRateSchedule): | class _ConvertToCell(LearningRateSchedule): | ||||
| @@ -22,12 +22,12 @@ from .optimizer import Optimizer | |||||
| _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt") | _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt") | ||||
| @_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor", | |||||
| @_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", | |||||
| "Tensor") | "Tensor") | ||||
| def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum): | def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum): | ||||
| """Apply sparse proximal_ada_grad optimizer to the weight parameter.""" | """Apply sparse proximal_ada_grad optimizer to the weight parameter.""" | ||||
| success = True | success = True | ||||
| success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices())) | |||||
| success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values, gradient.indices)) | |||||
| return success | return success | ||||
| @@ -49,6 +49,6 @@ class SparseToDense(Cell): | |||||
| self.sparse_to_dense = P.SparseToDense() | self.sparse_to_dense = P.SparseToDense() | ||||
| def construct(self, sparse_tensor): | def construct(self, sparse_tensor): | ||||
| return self.sparse_to_dense(sparse_tensor.indices(), | |||||
| sparse_tensor.values(), | |||||
| sparse_tensor.dense_shape()) | |||||
| return self.sparse_to_dense(sparse_tensor.indices, | |||||
| sparse_tensor.values, | |||||
| sparse_tensor.dense_shape) | |||||
| @@ -16,7 +16,7 @@ | |||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.nn.cell import Cell | from mindspore.nn.cell import Cell | ||||
| from mindspore.communication.management import GlobalComm, get_group_size | from mindspore.communication.management import GlobalComm, get_group_size | ||||
| from mindspore.common.tensor import IndexedSlices | |||||
| from mindspore.common.tensor import RowTensor | |||||
| from mindspore.ops import functional as F, composite as C, operations as P | from mindspore.ops import functional as F, composite as C, operations as P | ||||
| from mindspore.ops.operations.comm_ops import AllReduce, AllGather | from mindspore.ops.operations.comm_ops import AllReduce, AllGather | ||||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | from mindspore.parallel._auto_parallel_context import auto_parallel_context | ||||
| @@ -103,7 +103,7 @@ def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, | |||||
| return grad | return grad | ||||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices") | |||||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor") | |||||
| def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): | def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): | ||||
| """ | """ | ||||
| Apply allgather on gradient instead of allreduce for sparse feature. | Apply allgather on gradient instead of allreduce for sparse feature. | ||||
| @@ -118,21 +118,21 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce | |||||
| grad (tuple): The indices, gradient tensor and tensor_shape before operation. | grad (tuple): The indices, gradient tensor and tensor_shape before operation. | ||||
| Returns: | Returns: | ||||
| IndexedSlices, the gradient after operation. | |||||
| RowTensor, the gradient after operation. | |||||
| """ | """ | ||||
| if allreduce_filter: | if allreduce_filter: | ||||
| indices = allgather(grad.indices()) | |||||
| dout = allgather(grad.values()) | |||||
| indices = allgather(grad.indices) | |||||
| dout = allgather(grad.values) | |||||
| if mean: | if mean: | ||||
| degree = F.scalar_cast(degree, F.dtype(grad.values())) | |||||
| degree = F.scalar_cast(degree, F.dtype(grad.values)) | |||||
| cast_op = P.Cast() | cast_op = P.Cast() | ||||
| mul_op = P.Mul() | mul_op = P.Mul() | ||||
| dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) | dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) | ||||
| grad = IndexedSlices(indices, dout, grad.dense_shape()) | |||||
| grad = RowTensor(indices, dout, grad.dense_shape) | |||||
| return grad | return grad | ||||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool") | |||||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool") | |||||
| def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): | def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): | ||||
| """ | """ | ||||
| Apply allgather on gradient instead of allreduce for sparse feature. | Apply allgather on gradient instead of allreduce for sparse feature. | ||||
| @@ -148,20 +148,20 @@ def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allred | |||||
| ps_parameter (bool): Use parameter server or not. | ps_parameter (bool): Use parameter server or not. | ||||
| Returns: | Returns: | ||||
| IndexedSlices, the gradient after operation. | |||||
| RowTensor, the gradient after operation. | |||||
| """ | """ | ||||
| if ps_parameter: | if ps_parameter: | ||||
| return grad | return grad | ||||
| if allreduce_filter: | if allreduce_filter: | ||||
| indices = allgather(grad.indices()) | |||||
| dout = allgather(grad.values()) | |||||
| indices = allgather(grad.indices) | |||||
| dout = allgather(grad.values) | |||||
| if mean: | if mean: | ||||
| degree = F.scalar_cast(degree, F.dtype(grad.values())) | |||||
| degree = F.scalar_cast(degree, F.dtype(grad.values)) | |||||
| cast_op = P.Cast() | cast_op = P.Cast() | ||||
| mul_op = P.Mul() | mul_op = P.Mul() | ||||
| dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) | dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) | ||||
| grad = IndexedSlices(indices, dout, grad.dense_shape()) | |||||
| grad = RowTensor(indices, dout, grad.dense_shape) | |||||
| return grad | return grad | ||||
| @@ -182,18 +182,18 @@ def _tensors_get_datatype(grad): | |||||
| return F.dtype(grad) | return F.dtype(grad) | ||||
| @_get_datatype.register("IndexedSlices") | |||||
| @_get_datatype.register("RowTensor") | |||||
| def _tensors_get_datatype_with_sparse(grad): | def _tensors_get_datatype_with_sparse(grad): | ||||
| """ | """ | ||||
| Acquire gradient datatype. | Acquire gradient datatype. | ||||
| Args: | Args: | ||||
| grad (IndexedSlices): The gradient before operation. | |||||
| grad (RowTensor): The gradient before operation. | |||||
| Returns: | Returns: | ||||
| mstype, the datatype of gradient. | mstype, the datatype of gradient. | ||||
| """ | """ | ||||
| return F.dtype(grad.values()) | |||||
| return F.dtype(grad.values) | |||||
| _cast_datatype = C.MultitypeFuncGraph("_cast_datatype") | _cast_datatype = C.MultitypeFuncGraph("_cast_datatype") | ||||
| @@ -214,20 +214,20 @@ def _tensors_cast_datatype(datatype, grad): | |||||
| return F.cast(grad, datatype) | return F.cast(grad, datatype) | ||||
| @_cast_datatype.register("TypeType", "IndexedSlices") | |||||
| @_cast_datatype.register("TypeType", "RowTensor") | |||||
| def _tensors_cast_datatype_with_sparse(datatype, grad): | def _tensors_cast_datatype_with_sparse(datatype, grad): | ||||
| """ | """ | ||||
| Cast gradient to datatype. | Cast gradient to datatype. | ||||
| Args: | Args: | ||||
| datatype (mstype): the destination datatype of gradient. | datatype (mstype): the destination datatype of gradient. | ||||
| grad (IndexedSlices): The gradient before operation. | |||||
| grad (RowTensor): The gradient before operation. | |||||
| Returns: | Returns: | ||||
| IndexedSlices, the gradient after operation. | |||||
| RowTensor, the gradient after operation. | |||||
| """ | """ | ||||
| dout = F.cast(grad.values(), datatype) | |||||
| return IndexedSlices(grad.indices(), dout, grad.dense_shape()) | |||||
| dout = F.cast(grad.values, datatype) | |||||
| return RowTensor(grad.indices, dout, grad.dense_shape) | |||||
| class DistributedGradReducer(Cell): | class DistributedGradReducer(Cell): | ||||
| @@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||||
| from mindspore.train.parallel_utils import ParallelMode | from mindspore.train.parallel_utils import ParallelMode | ||||
| from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean | from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean | ||||
| from ..cell import Cell | from ..cell import Cell | ||||
| from ...common import Tensor, IndexedSlices | |||||
| from ...common import Tensor, RowTensor | |||||
| from ...common.parameter import Parameter | from ...common.parameter import Parameter | ||||
| from ...ops import functional as F | from ...ops import functional as F | ||||
| from ...ops import composite as C | from ...ops import composite as C | ||||
| @@ -35,11 +35,11 @@ reciprocal = P.Reciprocal() | |||||
| def tensor_grad_scale(scale, grad): | def tensor_grad_scale(scale, grad): | ||||
| return grad * F.cast(reciprocal(scale), F.dtype(grad)) | return grad * F.cast(reciprocal(scale), F.dtype(grad)) | ||||
| @_grad_scale.register("Tensor", "IndexedSlices") | |||||
| def tensor_grad_scale_indexed_slices(scale, grad): | |||||
| return IndexedSlices(grad.indices(), | |||||
| grad.values() * F.cast(reciprocal(scale), F.dtype(grad.values())), | |||||
| grad.dense_shape()) | |||||
| @_grad_scale.register("Tensor", "RowTensor") | |||||
| def tensor_grad_scale_row_tensor(scale, grad): | |||||
| return RowTensor(grad.indices, | |||||
| grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)), | |||||
| grad.dense_shape) | |||||
| _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | ||||
| grad_overflow = P.FloatStatus() | grad_overflow = P.FloatStatus() | ||||
| @@ -27,7 +27,7 @@ from .grad_base import bprop_getters | |||||
| from ..primitive import constexpr | from ..primitive import constexpr | ||||
| from ... import context | from ... import context | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.tensor import IndexedSlices | |||||
| from ...common.tensor import RowTensor | |||||
| reduce_sum = P.ReduceSum() | reduce_sum = P.ReduceSum() | ||||
| unsorted_segment_sum = P.UnsortedSegmentSum() | unsorted_segment_sum = P.UnsortedSegmentSum() | ||||
| @@ -75,12 +75,12 @@ def dout_cast_number(dout, x): | |||||
| dx = cast(dout, get_dtype(x)) | dx = cast(dout, get_dtype(x)) | ||||
| return dx | return dx | ||||
| @dout_cast.register("IndexedSlices", "Tensor") | |||||
| def dout_cast_indexed_slices(dout, x): | |||||
| @dout_cast.register("RowTensor", "Tensor") | |||||
| def dout_cast_row_tensor(dout, x): | |||||
| cast = P.Cast() | cast = P.Cast() | ||||
| get_dtype = P.DType() | get_dtype = P.DType() | ||||
| values = cast(dout.values(), get_dtype(x)) | |||||
| return IndexedSlices(dout.indices(), values, dout.dense_shape()) | |||||
| values = cast(dout.values, get_dtype(x)) | |||||
| return RowTensor(dout.indices, values, dout.dense_shape) | |||||
| @bprop_getters.register(P.Cast) | @bprop_getters.register(P.Cast) | ||||
| @@ -240,7 +240,7 @@ def get_bprop_embedding_lookup(self): | |||||
| actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail | actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail | ||||
| # Reshape the 'actual_dout' on device | # Reshape the 'actual_dout' on device | ||||
| actual_dout = reshape_op(dout, actual_dout_shape_changed) | actual_dout = reshape_op(dout, actual_dout_shape_changed) | ||||
| return IndexedSlices(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) | |||||
| return RowTensor(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) | |||||
| return bprop_sparse | return bprop_sparse | ||||
| @@ -369,7 +369,7 @@ def get_bprop_sparse_gather_v2(self): | |||||
| values_shape = indices_size + x_tail_shp | values_shape = indices_size + x_tail_shp | ||||
| values = reshape(dout, values_shape) | values = reshape(dout, values_shape) | ||||
| indices = reshape(indices, indices_size) | indices = reshape(indices, indices_size) | ||||
| return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis) | |||||
| return RowTensor(indices, values, x_shp), zeros_like(indices), zeros_like(axis) | |||||
| if F.rank(dout) == 0: | if F.rank(dout) == 0: | ||||
| dout = P.ExpandDims()(dout, -1) | dout = P.ExpandDims()(dout, -1) | ||||
| if F.rank(indices) == 0: | if F.rank(indices) == 0: | ||||
| @@ -17,7 +17,7 @@ | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from .. import operations as P | from .. import operations as P | ||||
| from ...common.tensor import IndexedSlices | |||||
| from ...common.tensor import RowTensor | |||||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | from ..composite.multitype_ops.zeros_like_impl import zeros_like | ||||
| from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, | from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, | ||||
| _GetTensorSlice, _MirrorOperator, ReduceOp, | _GetTensorSlice, _MirrorOperator, ReduceOp, | ||||
| @@ -47,9 +47,9 @@ def get_bprop_all_reduce(self): | |||||
| if F.issubclass_(F.typeof(dout), mstype.tensor): | if F.issubclass_(F.typeof(dout), mstype.tensor): | ||||
| dx = all_reduce_grad(dout) | dx = all_reduce_grad(dout) | ||||
| else: | else: | ||||
| indices = all_gather(dout.indices()) | |||||
| grad = all_gather(dout.values()) | |||||
| dx = IndexedSlices(indices, grad, dout.dense_shape()) | |||||
| indices = all_gather(dout.indices) | |||||
| grad = all_gather(dout.values) | |||||
| dx = RowTensor(indices, grad, dout.dense_shape) | |||||
| return (dx,) | return (dx,) | ||||
| else: | else: | ||||
| @@ -60,12 +60,12 @@ def get_bprop_all_reduce(self): | |||||
| z = cast(z, dtype(dx)) | z = cast(z, dtype(dx)) | ||||
| dx = mul(dx, z) | dx = mul(dx, z) | ||||
| else: | else: | ||||
| indices = all_gather(dout.indices()) | |||||
| grad = all_gather(dout.values()) | |||||
| indices = all_gather(dout.indices) | |||||
| grad = all_gather(dout.values) | |||||
| z = equal(x, out) | z = equal(x, out) | ||||
| z = cast(z, dtype(grad)) | z = cast(z, dtype(grad)) | ||||
| grad = mul(grad, z) | grad = mul(grad, z) | ||||
| dx = IndexedSlices(indices, grad, dout.dense_shape()) | |||||
| dx = RowTensor(indices, grad, dout.dense_shape) | |||||
| return (dx,) | return (dx,) | ||||
| return bprop | return bprop | ||||
| @@ -195,19 +195,19 @@ def get_bprop_mirror_operator(self): | |||||
| num = F.scalar_cast(dev_num, F.dtype(dx)) | num = F.scalar_cast(dev_num, F.dtype(dx)) | ||||
| dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx))) | dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx))) | ||||
| else: | else: | ||||
| indices = all_gather(dout.indices()) | |||||
| grad = all_gather(dout.values()) | |||||
| indices = all_gather(dout.indices) | |||||
| grad = all_gather(dout.values) | |||||
| float_one = F.scalar_cast(1.0, F.dtype(grad)) | float_one = F.scalar_cast(1.0, F.dtype(grad)) | ||||
| num = F.scalar_cast(dev_num, F.dtype(grad)) | num = F.scalar_cast(dev_num, F.dtype(grad)) | ||||
| grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad))) | grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad))) | ||||
| dx = IndexedSlices(indices, grad, dout.dense_shape()) | |||||
| dx = RowTensor(indices, grad, dout.dense_shape) | |||||
| else: | else: | ||||
| if F.issubclass_(F.typeof(dout), mstype.tensor): | if F.issubclass_(F.typeof(dout), mstype.tensor): | ||||
| dx = all_reduce(dout) | dx = all_reduce(dout) | ||||
| else: | else: | ||||
| indices = all_gather(dout.indices()) | |||||
| grad = all_gather(dout.values()) | |||||
| dx = IndexedSlices(indices, grad, dout.dense_shape()) | |||||
| indices = all_gather(dout.indices) | |||||
| grad = all_gather(dout.values) | |||||
| dx = RowTensor(indices, grad, dout.dense_shape) | |||||
| return (dx,) | return (dx,) | ||||
| return bprop | return bprop | ||||
| @@ -152,10 +152,10 @@ shape_mul = Primitive("shape_mul") | |||||
| # a primitive to compare between tuple. | # a primitive to compare between tuple. | ||||
| stop_gradient = Primitive("stop_gradient") | stop_gradient = Primitive("stop_gradient") | ||||
| make_indexed_slices = Primitive('MakeIndexedSlices') | |||||
| indexed_slices_get_values = Primitive('IndexedSlicesGetValues') | |||||
| indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') | |||||
| indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') | |||||
| make_row_tensor = Primitive('MakeRowTensor') | |||||
| row_tensor_get_values = Primitive('RowTensorGetValues') | |||||
| row_tensor_get_indices = Primitive('RowTensorGetIndices') | |||||
| row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape') | |||||
| make_sparse_tensor = Primitive('MakeSparseTensor') | make_sparse_tensor = Primitive('MakeSparseTensor') | ||||
| sparse_tensor_get_values = Primitive('SparseTensorGetValues') | sparse_tensor_get_values = Primitive('SparseTensorGetValues') | ||||
| @@ -389,8 +389,8 @@ class CheckBprop(PrimitiveWithInfer): | |||||
| validator.check_value_type('grads', xshapes, (tuple,), tips) | validator.check_value_type('grads', xshapes, (tuple,), tips) | ||||
| validator.check_value_type('params', yshapes, (tuple,), tips) | validator.check_value_type('params', yshapes, (tuple,), tips) | ||||
| if len(xshapes) < len(yshapes): | if len(xshapes) < len(yshapes): | ||||
| raise TypeError(f"{tips}, the size of output should be {len(yshapes)}," | |||||
| f" but got {len(xshapes)}.") | |||||
| raise ValueError(f"{tips}, the size of output should be {len(yshapes)}," | |||||
| f" but got {len(xshapes)}.") | |||||
| checking_range = len(yshapes) | checking_range = len(yshapes) | ||||
| for i in range(checking_range): | for i in range(checking_range): | ||||
| xshape = xshapes[i] | xshape = xshapes[i] | ||||
| @@ -398,8 +398,8 @@ class CheckBprop(PrimitiveWithInfer): | |||||
| if not xshape or not yshape: | if not xshape or not yshape: | ||||
| continue | continue | ||||
| if xshape != yshape: | if xshape != yshape: | ||||
| raise TypeError(f"{tips}, the shape of {i}th output should be {yshape}," | |||||
| f" but got {xshape}.") | |||||
| raise ValueError(f"{tips}, the shape of {i}th output should be {yshape}," | |||||
| f" but got {xshape}.") | |||||
| return xshapes | return xshapes | ||||
| def infer_dtype(self, xdtypes, ydtypes): | def infer_dtype(self, xdtypes, ydtypes): | ||||
| @@ -407,8 +407,8 @@ class CheckBprop(PrimitiveWithInfer): | |||||
| validator.check_value_type('grads', xdtypes, (tuple,), tips) | validator.check_value_type('grads', xdtypes, (tuple,), tips) | ||||
| validator.check_value_type('params', ydtypes, (tuple,), tips) | validator.check_value_type('params', ydtypes, (tuple,), tips) | ||||
| if len(xdtypes) < len(ydtypes): | if len(xdtypes) < len(ydtypes): | ||||
| raise TypeError(f"{tips}, the size of output should be {len(ydtypes)}," | |||||
| f" but got {len(xdtypes)}.") | |||||
| raise ValueError(f"{tips}, the size of output should be {len(ydtypes)}," | |||||
| f" but got {len(xdtypes)}.") | |||||
| checking_range = len(ydtypes) | checking_range = len(ydtypes) | ||||
| for i in range(checking_range): | for i in range(checking_range): | ||||
| xdtype = xdtypes[i] | xdtype = xdtypes[i] | ||||
| @@ -19,25 +19,16 @@ import pytest | |||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Parameter | |||||
| from mindspore import Parameter, ParameterTuple | |||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from .....mindspore_test_framework.utils.bprop_util import bprop | |||||
| def setup_module(module): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
| def teardown_module(module): | |||||
| context.set_context(device_target="Ascend") | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class MulAdd(nn.Cell): | class MulAdd(nn.Cell): | ||||
| def __init__(self): | |||||
| super(MulAdd, self).__init__() | |||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| return 2 * x + y | return 2 * x + y | ||||
| @@ -45,7 +36,9 @@ class MulAdd(nn.Cell): | |||||
| # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result | # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result | ||||
| return 2 * dout, 2 * y | return 2 * dout, 2 * y | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_grad_mul_add(): | def test_grad_mul_add(): | ||||
| mul_add = MulAdd() | mul_add = MulAdd() | ||||
| x = Tensor(1, dtype=ms.int32) | x = Tensor(1, dtype=ms.int32) | ||||
| @@ -62,7 +55,9 @@ class InlineMulADD(nn.Cell): | |||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| return self.mul_add(x, y) + x + self.param * y | return self.mul_add(x, y) + x + self.param * y | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_grad_inline_mul_add(): | def test_grad_inline_mul_add(): | ||||
| inline_mul_add = InlineMulADD() | inline_mul_add = InlineMulADD() | ||||
| x = Tensor(1, dtype=ms.int32) | x = Tensor(1, dtype=ms.int32) | ||||
| @@ -83,7 +78,9 @@ class WithParameter(nn.Cell): | |||||
| # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result | # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result | ||||
| return self.param1 * self.param2 * dout, 2 * y | return self.param1 * self.param2 * dout, 2 * y | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_with_param(): | def test_with_param(): | ||||
| with_param = WithParameter() | with_param = WithParameter() | ||||
| with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||
| @@ -91,20 +88,21 @@ def test_with_param(): | |||||
| class WithNoBprop(nn.Cell): | class WithNoBprop(nn.Cell): | ||||
| def __init__(self): | |||||
| super(WithNoBprop, self).__init__() | |||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| return 2 * x + y | return 2 * x + y | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_with_no_bprop(): | def test_with_no_bprop(): | ||||
| with_no_bprop = WithNoBprop() | with_no_bprop = WithNoBprop() | ||||
| x = Tensor(1, dtype=ms.int32) | x = Tensor(1, dtype=ms.int32) | ||||
| y = Tensor(2, dtype=ms.int32) | y = Tensor(2, dtype=ms.int32) | ||||
| C.grad_all(with_no_bprop)(x, y) | |||||
| assert C.grad_all(with_no_bprop)(x, y) == (2, 1) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_grad_in_bprop_1(): | def test_grad_in_bprop_1(): | ||||
| class GradInBprop_1(nn.Cell): | class GradInBprop_1(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -140,7 +138,9 @@ def test_grad_in_bprop_1(): | |||||
| assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() | assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() | ||||
| assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all() | assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all() | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_grad_in_bprop_2(): | def test_grad_in_bprop_2(): | ||||
| class GradInBprop_1(nn.Cell): | class GradInBprop_1(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -179,7 +179,9 @@ def test_grad_in_bprop_2(): | |||||
| assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() | assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() | ||||
| assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all() | assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all() | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_grad_in_bprop_3(): | def test_grad_in_bprop_3(): | ||||
| class GradInBprop_1(nn.Cell): | class GradInBprop_1(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -230,7 +232,9 @@ class OneInputBprop(nn.Cell): | |||||
| def bprop(self, x, out, dout): | def bprop(self, x, out, dout): | ||||
| return (5 * x,) | return (5 * x,) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_grad_one_input_bprop(): | def test_grad_one_input_bprop(): | ||||
| net = OneInputBprop() | net = OneInputBprop() | ||||
| input1 = Tensor(np.ones([2, 2]).astype(np.float32)) | input1 = Tensor(np.ones([2, 2]).astype(np.float32)) | ||||
| @@ -239,9 +243,6 @@ def test_grad_one_input_bprop(): | |||||
| class TwoInput(nn.Cell): | class TwoInput(nn.Cell): | ||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| return x * y | return x * y | ||||
| @@ -258,12 +259,17 @@ class InlineBpropTwoInput(nn.Cell): | |||||
| grads = C.grad_all(self.f)(x, y) | grads = C.grad_all(self.f)(x, y) | ||||
| return grads[0] * 2, grads[1] * 2 | return grads[0] * 2, grads[1] * 2 | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_grad_inline_bprop_two_input(): | def test_grad_inline_bprop_two_input(): | ||||
| net = InlineBpropTwoInput() | net = InlineBpropTwoInput() | ||||
| input1 = Tensor(np.ones([2, 2]).astype(np.float32)) | input1 = Tensor(np.ones([2, 2]).astype(np.float32)) | ||||
| input2 = Tensor(np.ones([2, 2]).astype(np.float32)) | input2 = Tensor(np.ones([2, 2]).astype(np.float32)) | ||||
| C.grad_all(net)(input1, input2) | |||||
| grads = C.grad_all(net)(input1, input2) | |||||
| assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() | |||||
| assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all() | |||||
| assert len(grads) == 2 | |||||
| class TwoInputBprop(nn.Cell): | class TwoInputBprop(nn.Cell): | ||||
| @@ -314,7 +320,9 @@ class InlineMutilTwoInputParameterCell(nn.Cell): | |||||
| output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y) | output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y) | ||||
| return output | return output | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_grad_inline_bprop_multi_input(): | def test_grad_inline_bprop_multi_input(): | ||||
| net = InlineMutilTwoInputParameterCell() | net = InlineMutilTwoInputParameterCell() | ||||
| input1 = Tensor(np.ones([2, 2]).astype(np.float32)) | input1 = Tensor(np.ones([2, 2]).astype(np.float32)) | ||||
| @@ -335,29 +343,54 @@ class MulAddWithParam(nn.Cell): | |||||
| def construct(self, x): | def construct(self, x): | ||||
| return self.mul_add(self.param, x) | return self.mul_add(self.param, x) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_refkey_bprop(): | def test_refkey_bprop(): | ||||
| net = MulAddWithParam() | |||||
| grad_by_list = C.GradOperation('get_by_list', get_all=True, get_by_list=True) | |||||
| class GradWrap(nn.Cell): | |||||
| def __init__(self, network): | |||||
| super(GradWrap, self).__init__() | |||||
| self.network = network | |||||
| self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters())) | |||||
| def construct(self, x): | |||||
| weights = self.weights | |||||
| grads = grad_by_list(self.network, weights)(x) | |||||
| return grads | |||||
| network = GradWrap(MulAddWithParam()) | |||||
| input_data = Tensor(np.array([2, 2], np.float32)) | input_data = Tensor(np.array([2, 2], np.float32)) | ||||
| grads = bprop(net, input_data, | |||||
| grads_wrt_outputs=(Tensor(np.ones([1, 2]).astype(np.float32))), | |||||
| wrt=['params', 'inputs'], | |||||
| params=net.trainable_params()) | |||||
| grads = network(input_data) | |||||
| assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all() | assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all() | ||||
| assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() | assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() | ||||
| class MulAddWithWrongOutputType(nn.Cell): | |||||
| def __init__(self): | |||||
| super(MulAddWithWrongOutputType, self).__init__() | |||||
| class MulAddWithWrongOutputNum(nn.Cell): | |||||
| def construct(self, x, y): | |||||
| return 2 * x + y | |||||
| def bprop(self, x, y, out, dout): | |||||
| return (2 * dout,) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_grad_mul_add_with_wrong_output_num(): | |||||
| context.set_context(check_bprop=True) | |||||
| mul_add = MulAddWithWrongOutputNum() | |||||
| with pytest.raises(TypeError): | |||||
| C.grad_all(mul_add)(1, 2) | |||||
| class MulAddWithWrongOutputType(nn.Cell): | |||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| return 2 * x + y | return 2 * x + y | ||||
| def bprop(self, x, y, out, dout): | def bprop(self, x, y, out, dout): | ||||
| return 2 * dout, 2 | return 2 * dout, 2 | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_grad_mul_add_with_wrong_output_type(): | def test_grad_mul_add_with_wrong_output_type(): | ||||
| context.set_context(check_bprop=True) | context.set_context(check_bprop=True) | ||||
| mul_add = MulAddWithWrongOutputType() | mul_add = MulAddWithWrongOutputType() | ||||
| @@ -376,7 +409,9 @@ class MulAddWithWrongOutputShape(nn.Cell): | |||||
| def bprop(self, x, y, out, dout): | def bprop(self, x, y, out, dout): | ||||
| return 2, self.ones | return 2, self.ones | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_grad_mul_add_with_wrong_output_shape(): | def test_grad_mul_add_with_wrong_output_shape(): | ||||
| context.set_context(check_bprop=True) | context.set_context(check_bprop=True) | ||||
| mul_add = MulAddWithWrongOutputShape() | mul_add = MulAddWithWrongOutputShape() | ||||
| @@ -606,14 +606,14 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) { | |||||
| ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); | ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); | ||||
| } | } | ||||
| TEST_F(TestOptLib, test_indexed_slices) { | |||||
| FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_indices"); | |||||
| FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_indices"); | |||||
| FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_values"); | |||||
| FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_values"); | |||||
| FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_dense_shape"); | |||||
| FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_dense_shape"); | |||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.indexed_slices_eliminate_}); | |||||
| TEST_F(TestOptLib, test_row_tensor) { | |||||
| FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_row_tensor", "before_get_indices"); | |||||
| FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_row_tensor", "after_get_indices"); | |||||
| FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_row_tensor", "before_get_values"); | |||||
| FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_row_tensor", "after_get_values"); | |||||
| FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_row_tensor", "before_get_dense_shape"); | |||||
| FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_row_tensor", "after_get_dense_shape"); | |||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.row_tensor_eliminate_}); | |||||
| ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns)); | ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns)); | ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns)); | ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns)); | ||||
| @@ -1130,17 +1130,17 @@ def test_adjust_allreduce_mul_add(tag): | |||||
| return fns[tag] | return fns[tag] | ||||
| def test_indexed_slices(tag): | |||||
| def test_row_tensor(tag): | |||||
| """ test_add_zero """ | """ test_add_zero """ | ||||
| fns = FnDict() | fns = FnDict() | ||||
| make_indexed_slices = Primitive('MakeIndexedSlices') | |||||
| indexed_slices_get_values = Primitive('IndexedSlicesGetValues') | |||||
| indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') | |||||
| indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') | |||||
| make_row_tensor = Primitive('MakeRowTensor') | |||||
| row_tensor_get_values = Primitive('RowTensorGetValues') | |||||
| row_tensor_get_indices = Primitive('RowTensorGetIndices') | |||||
| row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape') | |||||
| @fns | @fns | ||||
| def before_get_indices(x, y, z): | def before_get_indices(x, y, z): | ||||
| return indexed_slices_get_indices(make_indexed_slices(x, y, z)) | |||||
| return row_tensor_get_indices(make_row_tensor(x, y, z)) | |||||
| @fns | @fns | ||||
| def after_get_indices(x, y, z): | def after_get_indices(x, y, z): | ||||
| @@ -1148,7 +1148,7 @@ def test_indexed_slices(tag): | |||||
| @fns | @fns | ||||
| def before_get_values(x, y, z): | def before_get_values(x, y, z): | ||||
| return indexed_slices_get_values(make_indexed_slices(x, y, z)) | |||||
| return row_tensor_get_values(make_row_tensor(x, y, z)) | |||||
| @fns | @fns | ||||
| def after_get_values(x, y, z): | def after_get_values(x, y, z): | ||||
| @@ -1156,7 +1156,7 @@ def test_indexed_slices(tag): | |||||
| @fns | @fns | ||||
| def before_get_dense_shape(x, y, z): | def before_get_dense_shape(x, y, z): | ||||
| return indexed_slices_get_dense_shape(make_indexed_slices(x, y, z)) | |||||
| return row_tensor_get_dense_shape(make_row_tensor(x, y, z)) | |||||
| @fns | @fns | ||||
| def after_get_dense_shape(x, y, z): | def after_get_dense_shape(x, y, z): | ||||
| @@ -13,10 +13,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """ | """ | ||||
| @File : test_indexed_slices.py | |||||
| @File : test_row_tensor.py | |||||
| @Author: | @Author: | ||||
| @Date : 2020-06-08 | @Date : 2020-06-08 | ||||
| @Desc : test mindspore indexed_slices's operation | |||||
| @Desc : test mindspore row_tensor's operation | |||||
| """ | """ | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | import pytest | ||||
| @@ -29,7 +29,7 @@ from mindspore.ops import operations as P | |||||
| from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like | from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like | ||||
| from mindspore.ops.primitive import constexpr | from mindspore.ops.primitive import constexpr | ||||
| from mindspore.ops._grad.grad_base import bprop_getters | from mindspore.ops._grad.grad_base import bprop_getters | ||||
| from mindspore import Tensor, IndexedSlices, context | |||||
| from mindspore import Tensor, RowTensor, context | |||||
| from mindspore.common.parameter import Parameter, ParameterTuple | from mindspore.common.parameter import Parameter, ParameterTuple | ||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| @@ -122,7 +122,7 @@ def get_bprop_sparse_gather_v2(self): | |||||
| values_shape = indices_size + x_tail_shp | values_shape = indices_size + x_tail_shp | ||||
| values = reshape(dout, values_shape) | values = reshape(dout, values_shape) | ||||
| indices = reshape(indices, indices_size) | indices = reshape(indices, indices_size) | ||||
| return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis) | |||||
| return RowTensor(indices, values, x_shp), zeros_like(indices), zeros_like(axis) | |||||
| if F.rank(dout) == 0: | if F.rank(dout) == 0: | ||||
| dout = P.ExpandDims()(dout, -1) | dout = P.ExpandDims()(dout, -1) | ||||
| if F.rank(indices) == 0: | if F.rank(indices) == 0: | ||||
| @@ -142,10 +142,10 @@ def get_bprop_sparse_gather_v2(self): | |||||
| adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") | adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") | ||||
| @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | ||||
| "Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool") | |||||
| def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param, | |||||
| m, v, gradient, decay_flag): | |||||
| return gradient.values() | |||||
| "Tensor", "Tensor", "Tensor", "RowTensor", "Bool") | |||||
| def _update_run_op_for_map_row_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param, | |||||
| m, v, gradient, decay_flag): | |||||
| return gradient.values | |||||
| @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | ||||
| "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | ||||
| @@ -219,35 +219,35 @@ class AdamWeightDecaySparse(Optimizer): | |||||
| return updated_velocity | return updated_velocity | ||||
| def test_indexed_slices_make_indexed_slices(): | |||||
| class MakeIndexedSlices(nn.Cell): | |||||
| def test_row_tensor_make_row_tensor(): | |||||
| class MakeRowTensor(nn.Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(MakeIndexedSlices, self).__init__() | |||||
| super(MakeRowTensor, self).__init__() | |||||
| self.dense_shape = (3, 2) | self.dense_shape = (3, 2) | ||||
| def construct(self, indices, values): | def construct(self, indices, values): | ||||
| ret = (IndexedSlices(indices, values, self.dense_shape),) | |||||
| ret = (RowTensor(indices, values, self.dense_shape),) | |||||
| return ret[0] | return ret[0] | ||||
| indices = Tensor([1, 2]) | indices = Tensor([1, 2]) | ||||
| values = Tensor([[0, 0], [1, 2]], dtype=ms.float32) | values = Tensor([[0, 0], [1, 2]], dtype=ms.float32) | ||||
| MakeIndexedSlices()(indices, values) | |||||
| MakeRowTensor()(indices, values) | |||||
| class IndexedSlicesGetAttr(nn.Cell): | |||||
| class RowTensorGetAttr(nn.Cell): | |||||
| def __init__(self, dense_shape): | def __init__(self, dense_shape): | ||||
| super(IndexedSlicesGetAttr, self).__init__() | |||||
| super(RowTensorGetAttr, self).__init__() | |||||
| self.dense_shape = dense_shape | self.dense_shape = dense_shape | ||||
| def construct(self, indices, values): | def construct(self, indices, values): | ||||
| x = IndexedSlices(indices, values, self.dense_shape) | |||||
| return x.values(), x.indices(), x.dense_shape() | |||||
| x = RowTensor(indices, values, self.dense_shape) | |||||
| return x.values, x.indices, x.dense_shape | |||||
| def test_indexed_slices_attr(): | |||||
| def test_row_tensor_attr(): | |||||
| indices = Tensor([0]) | indices = Tensor([0]) | ||||
| values = Tensor([[1, 2]], dtype=ms.float32) | values = Tensor([[1, 2]], dtype=ms.float32) | ||||
| IndexedSlicesGetAttr((3, 2))(indices, values) | |||||
| RowTensorGetAttr((3, 2))(indices, values) | |||||
| def test_indexed_slices_sparse_gatherv2_grad_all(): | |||||
| def test_row_tensor_sparse_gatherv2_grad_all(): | |||||
| grad_all = C.GradOperation('get_all', get_all=True) | grad_all = C.GradOperation('get_all', get_all=True) | ||||
| class GradWrap(nn.Cell): | class GradWrap(nn.Cell): | ||||
| def __init__(self, network): | def __init__(self, network): | ||||
| @@ -255,7 +255,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all(): | |||||
| self.network = network | self.network = network | ||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| grad = grad_all(self.network)(x, y) | grad = grad_all(self.network)(x, y) | ||||
| return grad[0].indices(), grad[0].values(), grad[0].dense_shape() | |||||
| return grad[0].indices, grad[0].values, grad[0].dense_shape | |||||
| class SparseGatherV2(nn.Cell): | class SparseGatherV2(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(SparseGatherV2, self).__init__() | super(SparseGatherV2, self).__init__() | ||||
| @@ -268,7 +268,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all(): | |||||
| GradWrap(SparseGatherV2())(params, indices) | GradWrap(SparseGatherV2())(params, indices) | ||||
| def test_indexed_slices_sparse_gatherv2_grad_with_pram(): | |||||
| def test_row_tensor_sparse_gatherv2_grad_with_pram(): | |||||
| grad_by_list = C.GradOperation('get_by_list', get_by_list=True) | grad_by_list = C.GradOperation('get_by_list', get_by_list=True) | ||||
| class GradWrap(nn.Cell): | class GradWrap(nn.Cell): | ||||
| def __init__(self, network): | def __init__(self, network): | ||||
| @@ -279,7 +279,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): | |||||
| weights = self.weights | weights = self.weights | ||||
| grad = grad_by_list(self.network, weights)(x) | grad = grad_by_list(self.network, weights)(x) | ||||
| x = grad[0] | x = grad[0] | ||||
| return x.values(), x.indices(), x.dense_shape() | |||||
| return x.values, x.indices, x.dense_shape | |||||
| class SparseGatherV2(nn.Cell): | class SparseGatherV2(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(SparseGatherV2, self).__init__() | super(SparseGatherV2, self).__init__() | ||||
| @@ -293,7 +293,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): | |||||
| network(indices) | network(indices) | ||||
| def test_indexed_slices_env_get(): | |||||
| def test_row_tensor_env_get(): | |||||
| class Loss(nn.Cell): | class Loss(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Loss, self).__init__() | super(Loss, self).__init__() | ||||
| @@ -321,7 +321,7 @@ def test_indexed_slices_env_get(): | |||||
| train_network(inputs, label) | train_network(inputs, label) | ||||
| def test_indexed_slices_model_train(): | |||||
| def test_row_tensor_model_train(): | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| def __init__(self, in_features, out_features): | def __init__(self, in_features, out_features): | ||||
| super(Net, self).__init__() | super(Net, self).__init__() | ||||
| @@ -347,76 +347,76 @@ def test_indexed_slices_model_train(): | |||||
| model.train(2, dataset, dataset_sink_mode=False) | model.train(2, dataset, dataset_sink_mode=False) | ||||
| def test_indexed_slices_values_dim_greater_than_dense_shape_dim(): | |||||
| def test_row_tensor_values_dim_greater_than_dense_shape_dim(): | |||||
| indices = Tensor(np.array([0, 1], dtype=np.int32)) | indices = Tensor(np.array([0, 1], dtype=np.int32)) | ||||
| values = Tensor(np.random.randn(2, 4, 5).astype(np.float32)) | values = Tensor(np.random.randn(2, 4, 5).astype(np.float32)) | ||||
| dense_shape = (3, 4) | dense_shape = (3, 4) | ||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| IndexedSlicesGetAttr(dense_shape)(indices, values) | |||||
| RowTensorGetAttr(dense_shape)(indices, values) | |||||
| def test_indexed_slices_values_dim_less_than_dense_shape_dim(): | |||||
| def test_row_tensor_values_dim_less_than_dense_shape_dim(): | |||||
| indices = Tensor(np.array([0, 1], dtype=np.int32)) | indices = Tensor(np.array([0, 1], dtype=np.int32)) | ||||
| values = Tensor(np.random.randn(2, 4).astype(np.float32)) | values = Tensor(np.random.randn(2, 4).astype(np.float32)) | ||||
| dense_shape = (3, 4, 5) | dense_shape = (3, 4, 5) | ||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| IndexedSlicesGetAttr(dense_shape)(indices, values) | |||||
| RowTensorGetAttr(dense_shape)(indices, values) | |||||
| def test_indexed_slices_value_and_dense_shape_illegal(): | |||||
| def test_row_tensor_value_and_dense_shape_illegal(): | |||||
| indices = Tensor(np.array([0, 1], dtype=np.int32)) | indices = Tensor(np.array([0, 1], dtype=np.int32)) | ||||
| values = Tensor(np.random.randn(2, 4).astype(np.float32)) | values = Tensor(np.random.randn(2, 4).astype(np.float32)) | ||||
| dense_shape = (3, 5) | dense_shape = (3, 5) | ||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| IndexedSlicesGetAttr(dense_shape)(indices, values) | |||||
| RowTensorGetAttr(dense_shape)(indices, values) | |||||
| class IndexedSlicesValuesDouble(nn.Cell): | |||||
| class RowTensorValuesDouble(nn.Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| def construct(self, x): | def construct(self, x): | ||||
| indices = x.indices() | |||||
| values = x.values() * 2 | |||||
| dense_shape = x.dense_shape() | |||||
| return IndexedSlices(indices, values, dense_shape) | |||||
| indices = x.indices | |||||
| values = x.values * 2 | |||||
| dense_shape = x.dense_shape | |||||
| return RowTensor(indices, values, dense_shape) | |||||
| class IndexedSlicesValuesAdd2(nn.Cell): | |||||
| class RowTensorValuesAdd2(nn.Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| def construct(self, x): | def construct(self, x): | ||||
| indices = x.indices() | |||||
| values = x.values() + 2 | |||||
| dense_shape = x.dense_shape() | |||||
| return IndexedSlices(indices, values, dense_shape) | |||||
| indices = x.indices | |||||
| values = x.values + 2 | |||||
| dense_shape = x.dense_shape | |||||
| return RowTensor(indices, values, dense_shape) | |||||
| class IndexedSlicesWithControlIf(nn.Cell): | |||||
| class RowTensorWithControlIf(nn.Cell): | |||||
| def __init__(self, dense_shape): | def __init__(self, dense_shape): | ||||
| super().__init__() | super().__init__() | ||||
| self.op1 = IndexedSlicesValuesDouble() | |||||
| self.op2 = IndexedSlicesValuesAdd2() | |||||
| self.op1 = RowTensorValuesDouble() | |||||
| self.op2 = RowTensorValuesAdd2() | |||||
| self.dense_shape = dense_shape | self.dense_shape = dense_shape | ||||
| def construct(self, a, b, indices, values): | def construct(self, a, b, indices, values): | ||||
| x = IndexedSlices(indices, values, self.dense_shape) | |||||
| x = RowTensor(indices, values, self.dense_shape) | |||||
| if a > b: | if a > b: | ||||
| x = self.op1(x) | x = self.op1(x) | ||||
| else: | else: | ||||
| x = self.op2(x) | x = self.op2(x) | ||||
| return x.indices(), x.values() | |||||
| return x.indices, x.values | |||||
| def test_indexed_slices_with_control_flow_if(): | |||||
| def test_row_tensor_with_control_flow_if(): | |||||
| a = Tensor(np.array(0).astype(np.int32)) | a = Tensor(np.array(0).astype(np.int32)) | ||||
| b = Tensor(np.array(2).astype(np.int32)) | b = Tensor(np.array(2).astype(np.int32)) | ||||
| indices = Tensor(np.array([0, 2]).astype(np.int32)) | indices = Tensor(np.array([0, 2]).astype(np.int32)) | ||||
| values = Tensor(np.ones([2, 2]).astype(np.float32)) | values = Tensor(np.ones([2, 2]).astype(np.float32)) | ||||
| dense_shape = (5, 2) | dense_shape = (5, 2) | ||||
| net = IndexedSlicesWithControlIf(dense_shape) | |||||
| net = RowTensorWithControlIf(dense_shape) | |||||
| net(a, b, indices, values) | net(a, b, indices, values) | ||||
| @@ -52,7 +52,7 @@ def test_sparse_tensor_attr(): | |||||
| self.dense_shape = (3, 4) | self.dense_shape = (3, 4) | ||||
| def construct(self, indices, values): | def construct(self, indices, values): | ||||
| x = SparseTensor(indices, values, self.dense_shape) | x = SparseTensor(indices, values, self.dense_shape) | ||||
| return x.values(), x.indices(), x.dense_shape() | |||||
| return x.values, x.indices, x.dense_shape | |||||
| indices = Tensor([[0, 1], [1, 2]]) | indices = Tensor([[0, 1], [1, 2]]) | ||||
| values = Tensor([1, 2], dtype=ms.float32) | values = Tensor([1, 2], dtype=ms.float32) | ||||
| @@ -175,7 +175,7 @@ def test_bprop_with_wrong_output_num(): | |||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| return BpropWithWrongOutputNum()(x, y) | return BpropWithWrongOutputNum()(x, y) | ||||
| with pytest.raises(TypeError): | |||||
| with pytest.raises(ValueError): | |||||
| C.grad_all(BpropWithWrongOutputNumCell())(1, 2) | C.grad_all(BpropWithWrongOutputNumCell())(1, 2) | ||||
| def test_bprop_with_wrong_output_type(): | def test_bprop_with_wrong_output_type(): | ||||
| @@ -247,7 +247,7 @@ def test_bprop_with_wrong_output_shape(): | |||||
| def construct(self, x): | def construct(self, x): | ||||
| return BpropWithWrongOutputShape()(x) | return BpropWithWrongOutputShape()(x) | ||||
| with pytest.raises(TypeError): | |||||
| with pytest.raises(ValueError): | |||||
| net = BpropWithWrongOutputShapeCell() | net = BpropWithWrongOutputShapeCell() | ||||
| net.set_grad() | net.set_grad() | ||||
| C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32))) | C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32))) | ||||
| @@ -20,7 +20,7 @@ | |||||
| """ | """ | ||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context, Tensor, IndexedSlices, SparseTensor | |||||
| from mindspore import context, Tensor, RowTensor, SparseTensor | |||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True) | context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True) | ||||
| @@ -36,18 +36,18 @@ class GradWrap(nn.Cell): | |||||
| return grad | return grad | ||||
| def test_indexed_slices_attr(): | |||||
| class IndexedSlicesGetAttr(nn.Cell): | |||||
| def test_row_tensor_attr(): | |||||
| class RowTensorGetAttr(nn.Cell): | |||||
| def __init__(self, dense_shape): | def __init__(self, dense_shape): | ||||
| super(IndexedSlicesGetAttr, self).__init__() | |||||
| super(RowTensorGetAttr, self).__init__() | |||||
| self.dense_shape = dense_shape | self.dense_shape = dense_shape | ||||
| def construct(self, indices, values): | def construct(self, indices, values): | ||||
| x = IndexedSlices(indices, values, self.dense_shape) | |||||
| return x.values(), x.indices(), x.dense_shape() | |||||
| x = RowTensor(indices, values, self.dense_shape) | |||||
| return x.values, x.indices, x.dense_shape | |||||
| indices = Tensor([0]) | indices = Tensor([0]) | ||||
| values = Tensor([[1, 2]], dtype=ms.float32) | values = Tensor([[1, 2]], dtype=ms.float32) | ||||
| IndexedSlicesGetAttr((3, 2))(indices, values) | |||||
| GradWrap(IndexedSlicesGetAttr((3, 2)))(indices, values) | |||||
| RowTensorGetAttr((3, 2))(indices, values) | |||||
| GradWrap(RowTensorGetAttr((3, 2)))(indices, values) | |||||
| def test_sparse_tensor_attr(): | def test_sparse_tensor_attr(): | ||||
| @@ -57,7 +57,7 @@ def test_sparse_tensor_attr(): | |||||
| self.dense_shape = (3, 4) | self.dense_shape = (3, 4) | ||||
| def construct(self, indices, values): | def construct(self, indices, values): | ||||
| x = SparseTensor(indices, values, self.dense_shape) | x = SparseTensor(indices, values, self.dense_shape) | ||||
| return x.values(), x.indices(), x.dense_shape() | |||||
| return x.values, x.indices, x.dense_shape | |||||
| indices = Tensor([[0, 1], [1, 2]]) | indices = Tensor([[0, 1], [1, 2]]) | ||||
| values = Tensor([1, 2], dtype=ms.float32) | values = Tensor([1, 2], dtype=ms.float32) | ||||