Merge pull request !26731 from wangrao124/add_csr_optags/v1.6.0
| @@ -577,7 +577,7 @@ bool AkgKernelBuilder::AkgKernelParallelBuild(const std::vector<AnfNodePtr> &anf | |||
| AkgKernelJsonGenerator akg_kernel_json_generator(option); | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| bool is_custom_node = IsPrimitiveCNode(cnode, prim::kPrimCustom); | |||
| bool is_custom_node = IsPrimitiveCNode(cnode, prim::kPrimCustom) || IsCustomCSROP(cnode); | |||
| // Graph kernel node and Custom node need to generate composite json | |||
| if (AnfAlgo::IsGraphKernel(cnode) || is_custom_node) { | |||
| FuncGraphPtr func_graph = is_custom_node ? cnode->func_graph() : AnfAlgo::GetCNodeFuncGraphPtr(cnode); | |||
| @@ -25,6 +25,7 @@ | |||
| #include "nlohmann/json.hpp" | |||
| #include "backend/kernel_compiler/oplib/opinfo.h" | |||
| #include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h" | |||
| #include "utils/convert_utils.h" | |||
| namespace mindspore::graphkernel { | |||
| using kernel::OpAttrPtr; | |||
| @@ -54,6 +54,9 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { | |||
| Register(prim::kPrimReduceAny->name(), {1}); | |||
| Register(prim::kPrimUnsortedSegmentMin->name(), {2}); | |||
| Register(prim::kPrimUnsortedSegmentMax->name(), {2}); | |||
| Register(prim::kPrimCSRMul->name(), {3}); | |||
| Register(prim::kPrimCSRReduceSum->name(), {3, 4}); | |||
| Register(prim::kPrimCSRMV->name(), {3}); | |||
| Register(kSparseGatherV2OpName, {2}); | |||
| Register(kUnsortedSegmentProdOpName, {2}); | |||
| Register(kSimpleMeanGradOpName, {1}); | |||
| @@ -26,6 +26,9 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| using CSRTensor = mindspore::tensor::CSRTensor; | |||
| using CSRTensorPtr = mindspore::tensor::CSRTensorPtr; | |||
| // Convert CSRTensor Parameter or ValueNode to Tuple by setting its abstract. | |||
| void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) { | |||
| MS_EXCEPTION_IF_NULL(sparse); | |||
| @@ -44,6 +47,45 @@ void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) { | |||
| } | |||
| } | |||
| ValueNodePtr NewValueNodeAndSetAbstract(const ValuePtr &val, const AbstractBasePtr &abs) { | |||
| auto node = NewValueNode(val); | |||
| node->set_abstract(abs); | |||
| return node; | |||
| } | |||
| bool SplitValueNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs) { | |||
| ValuePtr value = node->cast<ValueNodePtr>()->value(); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| if (!value->isa<CSRTensor>()) return false; | |||
| auto csr_tensor = value->cast<CSRTensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(csr_tensor); | |||
| auto csr_abs = node->abstract()->cast<abstract::AbstractCSRTensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(csr_abs); | |||
| new_inputs->push_back(NewValueNodeAndSetAbstract(csr_tensor->GetIndptr(), csr_abs->indptr())); | |||
| new_inputs->push_back(NewValueNodeAndSetAbstract(csr_tensor->GetIndices(), csr_abs->indices())); | |||
| new_inputs->push_back(NewValueNodeAndSetAbstract(csr_tensor->GetValues(), csr_abs->values())); | |||
| auto shape_node = NewValueNode(csr_tensor->shape()); | |||
| shape_node->set_abstract(csr_abs->dense_shape()); | |||
| new_inputs->push_back(shape_node); | |||
| return true; | |||
| } | |||
| bool SplitCNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto sparse_prim = AnfAlgo::GetCNodePrimitive(cnode); | |||
| MS_EXCEPTION_IF_NULL(sparse_prim); | |||
| // Currently, only MakeCSR and MakeTuple nodes can be split. | |||
| if (make_sparse_set.count(sparse_prim->name()) <= 0 && sparse_prim->name().compare(prim::kPrimMakeTuple->name()) != 0) | |||
| return false; | |||
| auto sparse_inputs = cnode->inputs(); | |||
| for (size_t j = 1; j < sparse_inputs.size(); ++j) { | |||
| new_inputs->push_back(sparse_inputs[j]); | |||
| } | |||
| return true; | |||
| } | |||
| const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -90,7 +132,25 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An | |||
| auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), inputs[sparse_index], cons_node}, func_graph); | |||
| new_node->set_abstract(node->abstract()); | |||
| return new_node; | |||
| // ComputeSparse node: SparseTensorDenseMatmul, CSRDenseMul, CSRReduceSum | |||
| } else if (sparse_op_set.find(prim_name) != sparse_op_set.end()) { | |||
| const auto &inputs = cnode->inputs(); | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| new_inputs.push_back(inputs[0]); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| if (inputs[i]->isa<CNode>()) { | |||
| if (SplitCNode(inputs[i], &new_inputs)) continue; | |||
| } else if (inputs[i]->isa<ValueNode>()) { | |||
| if (SplitValueNode(inputs[i], &new_inputs)) continue; | |||
| } | |||
| new_inputs.push_back(inputs[i]); | |||
| } | |||
| auto new_node = cnode->func_graph()->NewCNode(new_inputs); | |||
| new_node->set_abstract(node->abstract()); | |||
| AnfAlgo::SetNodeAttr("is_csr", MakeValue(true), new_node); | |||
| return new_node; | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace opt | |||
| @@ -908,6 +908,14 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const TypePtr &type, size_t o | |||
| return elem->type_id(); | |||
| } | |||
| if (type_ptr->isa<CSRTensorType>()) { | |||
| auto tensor_ptr = type_ptr->cast<CSRTensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||
| TypePtr elem = tensor_ptr->element(); | |||
| MS_EXCEPTION_IF_NULL(elem); | |||
| return elem->type_id(); | |||
| } | |||
| return type_ptr->type_id(); | |||
| } | |||
| @@ -38,6 +38,8 @@ using Tensor = mindspore::tensor::Tensor; | |||
| using TensorPtr = mindspore::tensor::TensorPtr; | |||
| using MetaTensor = mindspore::tensor::MetaTensor; | |||
| using MetaTensorPtr = mindspore::tensor::MetaTensorPtr; | |||
| using CSRTensor = mindspore::tensor::CSRTensor; | |||
| using CSRTensorPtr = mindspore::tensor::CSRTensorPtr; | |||
| using InstanceCheckFunc = std::function<bool(const py::object &)>; | |||
| using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>; | |||
| @@ -486,6 +488,7 @@ std::vector<DataConverterPtr> GetDataConverters() { | |||
| // Convert data by python object type. | |||
| std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>), | |||
| std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>), | |||
| std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>), | |||
| std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple), | |||
| std::make_shared<ByTypeDataConverter<py::list>>(ConvertList), | |||
| std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>), | |||
| @@ -95,6 +95,7 @@ namespace mindspore { | |||
| namespace pipeline { | |||
| using Tensor = mindspore::tensor::Tensor; | |||
| using MetaTensor = mindspore::tensor::MetaTensor; | |||
| using CSRTensor = mindspore::tensor::CSRTensor; | |||
| using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>; | |||
| using mindspore::abstract::AbstractTensor; | |||
| using mindspore::abstract::AbstractTensorPtr; | |||
| @@ -170,7 +171,8 @@ bool CheckArgValid(const py::handle &arg) { | |||
| } | |||
| return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) || | |||
| py::isinstance<Number>(arg) || (py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__")); | |||
| py::isinstance<Number>(arg) || | |||
| ((py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg)) && !py::hasattr(arg, "__parameter__")); | |||
| } | |||
| std::string GetCompileExceptionInfo() { | |||
| @@ -151,10 +151,15 @@ REGISTER_PYBIND_DEFINE( | |||
| TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>())))); | |||
| return data; | |||
| })); | |||
| (void)py::class_<RowTensorType, Type, std::shared_ptr<RowTensorType>>(m_sub, "RowTensorType").def(py::init()); | |||
| (void)py::class_<RowTensorType, Type, std::shared_ptr<RowTensorType>>(m_sub, "RowTensorType") | |||
| .def(py::init()) | |||
| .def_property_readonly("ElementType", &RowTensorType::element, "Get the RowTensorType's element type."); | |||
| (void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType") | |||
| .def(py::init()); | |||
| (void)py::class_<CSRTensorType, Type, std::shared_ptr<CSRTensorType>>(m_sub, "CSRTensorType").def(py::init()); | |||
| .def(py::init()) | |||
| .def_property_readonly("ElementType", &SparseTensorType::element, "Get the SparseTensorType's element type."); | |||
| (void)py::class_<CSRTensorType, Type, std::shared_ptr<CSRTensorType>>(m_sub, "CSRTensorType") | |||
| .def(py::init()) | |||
| .def_property_readonly("ElementType", &CSRTensorType::element, "Get the CSRTensorType's element type."); | |||
| (void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType") | |||
| .def(py::init()); | |||
| (void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function") | |||
| @@ -313,4 +313,10 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) { | |||
| } | |||
| return cnt; | |||
| } | |||
| bool IsCustomCSROP(const AnfNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV}; | |||
| return IsOneOfPrimitiveCNode(cnode, prims); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -92,6 +92,16 @@ const mindspore::HashMap<std::string, int64_t> sparse_attr_map = {{prim::kPrimCS | |||
| // sparse_process.cc | |||
| const mindspore::HashSet<std::string> make_sparse_set = { | |||
| {prim::kPrimMakeCSRTensor->name()}, {prim::kPrimMakeSparseTensor->name()}, {prim::kPrimMakeRowTensor->name()}}; | |||
| // sparse_op_set records all sparse_compute operators, which takes sparsetensor | |||
| // and (possibly) dense tensors, used in backend common optimization pass: | |||
| // sparse_process.cc | |||
| const mindspore::HashSet<std::string> sparse_op_set = {{prim::kPrimSparseTensorDenseMatmul->name()}, | |||
| {prim::kPrimCSRDenseMul->name()}, | |||
| {prim::kPrimCSRReduceSum->name()}, | |||
| {prim::kPrimCSRMV->name()}, | |||
| {prim::kPrimCSRMul->name()}}; | |||
| bool IsCustomCSROP(const AnfNodePtr &cnode); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_ | |||
| @@ -305,7 +305,7 @@ class _MindsporeFunctionExecutor: | |||
| return None | |||
| new_inputs = [] | |||
| for i in args_list: | |||
| if isinstance(i, Tensor): | |||
| if isinstance(i, (Tensor, CSRTensor)): | |||
| new_inputs.append(i) | |||
| elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)): | |||
| new_inputs.append(i) | |||
| @@ -155,7 +155,12 @@ AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -33,6 +33,9 @@ constexpr auto kRankSize = "rank_size"; | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| constexpr auto kCSRDenseShape = "dense_shape"; | |||
| constexpr auto kCSRAxis = "axis"; | |||
| constexpr auto kCSRAvgRows = "csr_avg_rows"; | |||
| AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // An object of a subclass of AbstractBase | |||
| @@ -395,6 +398,125 @@ AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, co | |||
| return sparse_tensor->dense_shape(); | |||
| } | |||
| AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a sparse tensor and a dense tensor. | |||
| constexpr auto kCSRMulInputsNum = 2; | |||
| constexpr auto kCSRMulShapeSize = 2; | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, kCSRMulInputsNum); | |||
| auto sparse = CheckArg<AbstractCSRTensor>(op_name, args_spec_list, 0); | |||
| auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(sparse); | |||
| MS_EXCEPTION_IF_NULL(sparse->shape()); | |||
| MS_EXCEPTION_IF_NULL(sparse->values()); | |||
| MS_EXCEPTION_IF_NULL(sparse->indices()); | |||
| MS_EXCEPTION_IF_NULL(dense); | |||
| auto sparse_shape = sparse->shape()->shape(); | |||
| auto dense_shape = dense->shape()->shape(); | |||
| if (sparse_shape.size() != kCSRMulShapeSize || dense_shape.size() != kCSRMulShapeSize) { | |||
| MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMulShapeSize << "-D inputs!" | |||
| << "but sparse tensor has " << sparse_shape.size() << " dimensions, " | |||
| << "and dense tensor has " << dense_shape.size() << " dimensions, "; | |||
| } | |||
| auto ret = sparse->values()->Broaden(); | |||
| MS_EXCEPTION_IF_NULL(sparse->indices()->shape()); | |||
| auto nnz_vec = sparse->indices()->shape()->shape(); | |||
| int csr_avg_rows = nnz_vec[0] / dense_shape[0]; | |||
| primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); | |||
| primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape)); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a sparse tensor and a dense tensor. | |||
| constexpr auto kCSRMVInputsNum = 2; | |||
| constexpr auto kCSRMVShapeSize = 2; | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, kCSRMVInputsNum); | |||
| auto sparse = CheckArg<AbstractCSRTensor>(op_name, args_spec_list, 0); | |||
| auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(sparse); | |||
| MS_EXCEPTION_IF_NULL(sparse->shape()); | |||
| MS_EXCEPTION_IF_NULL(sparse->values()); | |||
| MS_EXCEPTION_IF_NULL(sparse->indices()); | |||
| MS_EXCEPTION_IF_NULL(dense); | |||
| auto sparse_shape = sparse->shape()->shape(); | |||
| auto dense_shape = dense->shape()->shape(); | |||
| if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) { | |||
| MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs!" | |||
| << "but sparse tensor has " << sparse_shape.size() << " dimensions, " | |||
| << "and dense tensor has " << dense_shape.size() << " dimensions, "; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(sparse->values()->element()); | |||
| ShapeVector out_shape = {sparse_shape[0], dense_shape[1]}; | |||
| auto ret = std::make_shared<AbstractTensor>(sparse->values()->element()->BuildType(), out_shape); | |||
| MS_EXCEPTION_IF_NULL(sparse->indices()->shape()); | |||
| auto nnz_vec = sparse->indices()->shape()->shape(); | |||
| int csr_avg_rows = nnz_vec[0] / dense_shape[0]; | |||
| primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); | |||
| primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape)); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a sparse tensor and an axis. | |||
| constexpr auto kCSRReduceSumInputsNum = 2; | |||
| constexpr auto kCSRReduceSumShapeSize = 2; | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, kCSRReduceSumInputsNum); | |||
| auto sparse = CheckArg<AbstractCSRTensor>(op_name, args_spec_list, 0); | |||
| auto axis = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(sparse); | |||
| MS_EXCEPTION_IF_NULL(sparse->shape()); | |||
| MS_EXCEPTION_IF_NULL(sparse->values()); | |||
| MS_EXCEPTION_IF_NULL(sparse->indices()); | |||
| MS_EXCEPTION_IF_NULL(axis); | |||
| auto sparse_shape = sparse->shape()->shape(); | |||
| if (sparse_shape.size() != kCSRReduceSumShapeSize) { | |||
| MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRReduceSumShapeSize << "-D inputs!" | |||
| << "but sparse tensor has " << sparse_shape.size() << " dimensions."; | |||
| } | |||
| ShapeVector out_shape = sparse_shape; | |||
| MS_EXCEPTION_IF_NULL(axis->BuildValue()); | |||
| if (axis->BuildValue()->isa<Int32Imm>() || axis->BuildValue()->isa<Int64Imm>()) { | |||
| int64_t axis_value = GetValue<int64_t>(axis->BuildValue()); | |||
| int64_t dim = static_cast<int64_t>(sparse_shape.size()); | |||
| if (axis_value < -dim || axis_value >= dim) { | |||
| MS_LOG(EXCEPTION) << "axis should be in [" << -dim << ", " << dim << "). But got axis = " << axis_value; | |||
| } | |||
| if (axis_value >= -dim && axis_value < 0) { | |||
| axis_value += dim; | |||
| } | |||
| out_shape[LongToSize(axis_value)] = 1; | |||
| primitive->set_attr(kCSRAxis, MakeValue(axis_value)); | |||
| } else { | |||
| MS_EXCEPTION(ValueError) << "Currently, only support Integer axis."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(sparse->values()->element()); | |||
| auto ret = std::make_shared<AbstractTensor>(sparse->values()->element()->BuildType(), out_shape); | |||
| MS_EXCEPTION_IF_NULL(sparse->indices()->shape()); | |||
| auto nnz_vec = sparse->indices()->shape()->shape(); | |||
| int csr_avg_rows = nnz_vec[0] / sparse_shape[0]; | |||
| primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); | |||
| primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape)); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: three tensors and a tuple. | |||
| @@ -228,6 +228,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimCSRTensorGetIndptr, R{InferImplCSRTensorGetIndptr, nullptr, true}}, | |||
| {prim::kPrimCSRTensorGetIndices, R{InferImplCSRTensorGetIndices, nullptr, true}}, | |||
| {prim::kPrimCSRTensorGetDenseShape, R{InferImplCSRTensorGetDenseShape, nullptr, true}}, | |||
| {prim::kPrimCSRMul, R{InferImplCSRMul, nullptr, true}}, | |||
| {prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}}, | |||
| {prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}}, | |||
| // Comm Ops | |||
| {prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}}, | |||
| {prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}}, | |||
| @@ -494,6 +494,13 @@ inline const PrimitivePtr kPrimCSRTensorGetIndptr = std::make_shared<Primitive>( | |||
| inline const PrimitivePtr kPrimCSRTensorGetIndices = std::make_shared<Primitive>("CSRTensorGetIndices"); | |||
| inline const PrimitivePtr kPrimCSRTensorGetDenseShape = std::make_shared<Primitive>("CSRTensorGetDenseShape"); | |||
| // Sparse ops | |||
| inline const PrimitivePtr kPrimSparseTensorDenseMatmul = std::make_shared<Primitive>("SparseTensorDenseMatmul"); | |||
| inline const PrimitivePtr kPrimCSRDenseMul = std::make_shared<Primitive>("CSRDenseMul"); | |||
| inline const PrimitivePtr kPrimCSRReduceSum = std::make_shared<Primitive>("CSRReduceSum"); | |||
| inline const PrimitivePtr kPrimCSRMV = std::make_shared<Primitive>("CSRMV"); | |||
| inline const PrimitivePtr kPrimCSRMul = std::make_shared<Primitive>("CSRMul"); | |||
| // TensorList | |||
| inline const PrimitivePtr kPrimTensorListFromTensor = std::make_shared<Primitive>("TensorListFromTensor"); | |||
| inline const PrimitivePtr kPrimTensorListReserve = std::make_shared<Primitive>("TensorListReserve"); | |||
| @@ -436,6 +436,7 @@ inline const TypePtr kKeyword = std::make_shared<Keyword>(); | |||
| inline const TypePtr kTensorType = std::make_shared<TensorType>(); | |||
| inline const TypePtr kTensorTypeFP16 = std::make_shared<TensorType>(std::make_shared<Float>(16)); | |||
| inline const TypePtr kTensorTypeFP32 = std::make_shared<TensorType>(std::make_shared<Float>(32)); | |||
| inline const TypePtr kCSRTensorType = std::make_shared<CSRTensorType>(); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_IR_DTYPE_H_ | |||
| @@ -61,20 +61,44 @@ bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) c | |||
| } | |||
| TypePtr TypeIdToType(TypeId id) { | |||
| static mindspore::HashMap<TypeId, TypePtr> type_id_to_type = { | |||
| {kNumberTypeFloat16, kFloat16}, {kNumberTypeFloat, kFloat32}, {kNumberTypeFloat32, kFloat32}, | |||
| {kNumberTypeFloat64, kFloat64}, {kNumberTypeComplex64, kComplex64}, {kNumberTypeInt8, kInt8}, | |||
| {kNumberTypeInt16, kInt16}, {kNumberTypeInt32, kInt32}, {kNumberTypeInt, kInt32}, | |||
| {kNumberTypeInt64, kInt64}, {kNumberTypeUInt8, kUInt8}, {kNumberTypeUInt16, kUInt16}, | |||
| {kNumberTypeUInt32, kUInt32}, {kNumberTypeUInt64, kUInt64}, {kNumberTypeBool, kBool}, | |||
| {kNumberTypeComplex64, kComplex64}, {kNumberTypeComplex128, kComplex128}, {kMetaTypeExternal, kTypeExternal}, | |||
| {kMetaTypeAnything, kAnyType}, {kMetaTypeNone, kTypeNone}, {kMetaTypeNull, kTypeNull}, | |||
| {kMetaTypeEllipsis, kTypeEllipsis}, {kObjectTypeEnvType, kTypeEnv}, {kObjectTypeRefKey, kRefKeyType}, | |||
| {kObjectTypeRef, kRefType}, {kMetaTypeTypeType, kTypeType}, {kObjectTypeString, kString}, | |||
| {kObjectTypeList, kList}, {kObjectTypeTuple, kTuple}, {kObjectTypeDictionary, kDict}, | |||
| {kObjectTypeSlice, kSlice}, {kObjectTypeKeyword, kKeyword}, {kObjectTypeTensorType, kTensorType}, | |||
| {kObjectTypeUMonad, kUMonadType}, {kObjectTypeIOMonad, kIOMonadType}, {kTypeUnknown, kTypeNone}, | |||
| {kMetaTypeProblem, kTypeNone}}; | |||
| static mindspore::HashMap<TypeId, TypePtr> type_id_to_type = {{kNumberTypeFloat16, kFloat16}, | |||
| {kNumberTypeFloat, kFloat32}, | |||
| {kNumberTypeFloat32, kFloat32}, | |||
| {kNumberTypeFloat64, kFloat64}, | |||
| {kNumberTypeComplex64, kComplex64}, | |||
| {kNumberTypeInt8, kInt8}, | |||
| {kNumberTypeInt16, kInt16}, | |||
| {kNumberTypeInt32, kInt32}, | |||
| {kNumberTypeInt, kInt32}, | |||
| {kNumberTypeInt64, kInt64}, | |||
| {kNumberTypeUInt8, kUInt8}, | |||
| {kNumberTypeUInt16, kUInt16}, | |||
| {kNumberTypeUInt32, kUInt32}, | |||
| {kNumberTypeUInt64, kUInt64}, | |||
| {kNumberTypeBool, kBool}, | |||
| {kNumberTypeComplex64, kComplex64}, | |||
| {kNumberTypeComplex128, kComplex128}, | |||
| {kMetaTypeExternal, kTypeExternal}, | |||
| {kMetaTypeAnything, kAnyType}, | |||
| {kMetaTypeNone, kTypeNone}, | |||
| {kMetaTypeNull, kTypeNull}, | |||
| {kMetaTypeEllipsis, kTypeEllipsis}, | |||
| {kObjectTypeEnvType, kTypeEnv}, | |||
| {kObjectTypeRefKey, kRefKeyType}, | |||
| {kObjectTypeRef, kRefType}, | |||
| {kMetaTypeTypeType, kTypeType}, | |||
| {kObjectTypeString, kString}, | |||
| {kObjectTypeList, kList}, | |||
| {kObjectTypeTuple, kTuple}, | |||
| {kObjectTypeDictionary, kDict}, | |||
| {kObjectTypeSlice, kSlice}, | |||
| {kObjectTypeKeyword, kKeyword}, | |||
| {kObjectTypeTensorType, kTensorType}, | |||
| {kObjectTypeUMonad, kUMonadType}, | |||
| {kObjectTypeIOMonad, kIOMonadType}, | |||
| {kTypeUnknown, kTypeNone}, | |||
| {kMetaTypeProblem, kTypeNone}, | |||
| {kObjectTypeCSRTensorType, kCSRTensorType}}; | |||
| const auto &it = type_id_to_type.find(id); | |||
| if (it == type_id_to_type.end()) { | |||
| MS_LOG(EXCEPTION) << "Not support the type: " << id; | |||
| @@ -704,6 +704,17 @@ abstract::AbstractBasePtr CSRTensor::ToAbstract() { | |||
| MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << "."; | |||
| } | |||
| auto abs_csr_tensor = std::make_shared<abstract::AbstractCSRTensor>(dtype, shape_); | |||
| abs_csr_tensor->set_indptr(indptr_->ToAbstract()->cast<abstract::AbstractTensorPtr>()); | |||
| abs_csr_tensor->set_indices(indices_->ToAbstract()->cast<abstract::AbstractTensorPtr>()); | |||
| abs_csr_tensor->set_values(values_->ToAbstract()->cast<abstract::AbstractTensorPtr>()); | |||
| std::vector<abstract::AbstractBasePtr> abstract_shape; | |||
| std::transform( | |||
| shape_.begin(), shape_.end(), std::back_inserter(abstract_shape), | |||
| [](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared<abstract::AbstractScalar>(shp); }); | |||
| abs_csr_tensor->set_dense_shape(std::make_shared<abstract::AbstractTuple>(abstract_shape)); | |||
| return abs_csr_tensor; | |||
| } | |||
| } // namespace tensor | |||
| @@ -31,7 +31,7 @@ from .._checkparam import Validator | |||
| from ..common import dtype as mstype | |||
| from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor | |||
| from ..common.parameter import Parameter, ParameterTuple | |||
| from ..common.tensor import Tensor | |||
| from ..common.tensor import Tensor, CSRTensor | |||
| from ..ops.operations import HookBackward, Cast | |||
| from ..ops.primitive import Primitive | |||
| from ..parallel._tensor import _load_tensor_by_layout | |||
| @@ -808,6 +808,8 @@ class Cell(Cell_): | |||
| if i.has_init: | |||
| i.init_data() | |||
| new_inputs.append(i) | |||
| elif isinstance(i, CSRTensor): | |||
| new_inputs.append(i) | |||
| elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)): | |||
| new_inputs.append(i) | |||
| elif hasattr(self, "enable_tuple_broaden") and self.enable_tuple_broaden and isinstance(i, tuple) and\ | |||
| @@ -22,4 +22,7 @@ from .logical_or import _logical_or_akg | |||
| from .mean import _simple_mean_akg | |||
| from .mean_grad import _simple_mean_grad_akg | |||
| from .notequal import _notequal_akg | |||
| from .csr_reduce_sum import _csr_reduce_sum_akg | |||
| from .csr_mv import _csr_mv_akg | |||
| from .csr_mul import _csr_mul_akg | |||
| # Please insert op register in lexicographical order of the filename. | |||
| @@ -0,0 +1,36 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """CSRMul op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||
| csr_mul_op_info = AkgGpuRegOp("CSRMul") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "indptr") \ | |||
| .input(1, "indices") \ | |||
| .input(2, "values") \ | |||
| .input(4, "dense_tensor") \ | |||
| .output(0, "output0") \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \ | |||
| DataType.F32_Default, \ | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \ | |||
| DataType.F32_Default, \ | |||
| DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(csr_mul_op_info) | |||
| def _csr_mul_akg(): | |||
| """CSRMul AutoDiff register""" | |||
| return | |||
| @@ -0,0 +1,36 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """CSRMV op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||
| csr_mv_op_info = AkgGpuRegOp("CSRMV") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "indptr") \ | |||
| .input(1, "indices") \ | |||
| .input(2, "values") \ | |||
| .input(4, "dense_tensor") \ | |||
| .output(0, "output") \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \ | |||
| DataType.F32_Default, \ | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \ | |||
| DataType.F32_Default, \ | |||
| DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(csr_mv_op_info) | |||
| def _csr_mv_akg(): | |||
| """CSRMV AutoDiff register""" | |||
| return | |||
| @@ -0,0 +1,33 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """CSRReduceSum op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||
| csr_reduce_sum_op_info = AkgGpuRegOp("CSRReduceSum") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "indptr") \ | |||
| .input(1, "indices") \ | |||
| .input(2, "values") \ | |||
| .output(0, "output") \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \ | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \ | |||
| DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(csr_reduce_sum_op_info) | |||
| def _csr_reduce_sum_akg(): | |||
| """CSRReduceSum AutoDiff register""" | |||
| return | |||
| @@ -18,7 +18,7 @@ | |||
| from . import _compile_utils as utils | |||
| from ...composite import base | |||
| from ... import functional as F | |||
| from ....common import CSRTensor | |||
| mul = base.MultitypeFuncGraph("mul", True) | |||
| """ | |||
| @@ -27,6 +27,30 @@ using ".register" decorator. | |||
| """ | |||
| @mul.register("CSRTensor", "Tensor") | |||
| def _mul_csrtensor_tensor(x, y): | |||
| """ | |||
| Returns x * y where x is CSRTensor and y is Tensor. | |||
| Outputs: | |||
| CSRTensor, equal to x * y. | |||
| """ | |||
| data = F.csr_mul(x, y) | |||
| return CSRTensor(x.indptr, x.indices, data, x.shape) | |||
@mul.register("Tensor", "CSRTensor")
def _mul_tensor_csrtensor(x, y):
    """
    Elementwise product where `x` is a dense Tensor and `y` is a CSRTensor.

    Outputs:
        CSRTensor sharing `y`'s sparsity structure (indptr/indices/shape),
        with values equal to x * y.
    """
    # Multiplication commutes: delegate to csr_mul with the CSRTensor first.
    return CSRTensor(y.indptr, y.indices, F.csr_mul(y, x), y.shape)
| @mul.register("Number", "Number") | |||
| def _mul_scalar(x, y): | |||
| """ | |||
| @@ -27,6 +27,7 @@ from mindspore.ops.primitive import constexpr | |||
| from .primitive import Primitive | |||
| from . import operations as P | |||
| from .operations import _grad_ops | |||
| from .operations import _csr_ops | |||
| from .composite import _Grad | |||
| from .._c_expression import security | |||
| @@ -146,6 +147,10 @@ tensor_scatter_update = P.TensorScatterUpdate() | |||
scatter_nd_update = P.ScatterNdUpdate()
stack = P.Stack()
# Module-level singleton instances of the CSR sparse primitives, exposed so
# functional-style callers (F.csr_mul / F.csr_mv / F.csr_reduce_sum) reuse one
# primitive object instead of constructing a new one per call.
csr_mul = _csr_ops.CSRMul()
csr_mv = _csr_ops.CSRMV()
csr_reduce_sum = _csr_ops.CSRReduceSum()
| def pack(x): | |||
| """Call stack in this pack function.""" | |||
| @@ -0,0 +1,153 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """csr_ops""" | |||
| from ..primitive import prim_attr_register, PrimitiveWithInfer | |||
class CSRReduceSum(PrimitiveWithInfer):
    """
    Reduces a dimension of a CSRTensor by summing all elements in the dimension.

    Inputs:
        - **sparse_tensor** (CSRTensor) - A CSRTensor.
        - **axis** (int) - The dimensions to reduce.

    Outputs:
        Tensor, the dtype is the same as `sparse_tensor.values`.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, CSRTensor, ops
        >>> from mindspore import dtype as mstype
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.op = ops.CSRReduceSum()
        ...
        ...     def construct(self, indptr, indices, values, dense_shape, axis):
        ...         csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
        ...         return self.op(csr_tensor, axis)
        >>> indptr = Tensor([0, 1, 2])
        >>> indices = Tensor([0, 1])
        >>> values = Tensor([2, 1], dtype=mstype.float32)
        >>> dense_shape = (2, 4)
        >>> out = Net()(indptr, indices, values, dense_shape, 1)
        >>> print(out)
        [[2.]
         [1.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CSRReduceSum."""
        # The front end flattens the CSRTensor argument into its component
        # tensors, so the primitive registers five positional inputs.
        io_inputs = ['indptr', 'indices', 'values', 'dense_shape', 'axis']
        self.init_prim_io_names(inputs=io_inputs, outputs=['output'])
class CSRMV(PrimitiveWithInfer):
    """
    Sparse matrix-vector multiplication.

    Inputs:
        - **sparse_tensor** (CSRTensor) - A CSRTensor.
        - **dense_tensor** (Tensor) - A dense Tensor.

    Outputs:
        Tensor.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, CSRTensor, ops
        >>> from mindspore import dtype as mstype
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.op = ops.CSRMV()
        ...
        ...     def construct(self, indptr, indices, values, dense_shape, dense):
        ...         csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
        ...         return self.op(csr_tensor, dense)
        >>> indptr = Tensor([0, 1, 2])
        >>> indices = Tensor([0, 1])
        >>> values = Tensor([2, 1], dtype=mstype.float32)
        >>> dense_shape = (2, 4)
        >>> dense = Tensor([[1], [1], [1], [1]], dtype=mstype.float32)
        >>> out = Net()(indptr, indices, values, dense_shape, dense)
        >>> print(out)
        [[2.]
         [1.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CSRMV."""
        # The CSRTensor operand arrives flattened as four component inputs,
        # followed by the dense right-hand-side vector.
        io_inputs = ['indptr', 'indices', 'values', 'dense_shape', 'dense_tensor']
        self.init_prim_io_names(inputs=io_inputs, outputs=['output'])
class CSRMul(PrimitiveWithInfer):
    """
    Elemwise multiplication on a CSRTensor and a dense tensor.

    Note:
        The op outputs a 1-D dense tensor whose shape and values are the same as
        input `CSRTensor.values`. If you expect a CSRTensor output, please use
        `*` directly, e.g. `x * y`, where `x` or `y` can be a CSRTensor.

    Inputs:
        - **sparse_tensor** (CSRTensor) - A CSRTensor.
        - **dense_tensor** (Tensor) - A Tensor.

    Outputs:
        Tensor, the dtype and shape are the same as `sparse_tensor.values`.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, CSRTensor, ops
        >>> from mindspore import dtype as mstype
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.op = ops.CSRMul()
        ...
        ...     def construct(self, indptr, indices, values, dense_shape, dense):
        ...         csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
        ...         return self.op(csr_tensor, dense)
        >>> indptr = Tensor([0, 1, 2])
        >>> indices = Tensor([0, 1])
        >>> values = Tensor([2, 1], dtype=mstype.float32)
        >>> dense_shape = (2, 4)
        >>> dense = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
        >>> out = Net()(indptr, indices, values, dense_shape, dense)
        >>> print(out)
        [2. 1.]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CSRMul."""
        # The CSRTensor operand arrives flattened as four component inputs,
        # followed by the dense multiplicand.
        io_inputs = ['indptr', 'indices', 'values', 'dense_shape', 'dense_tensor']
        self.init_prim_io_names(inputs=io_inputs, outputs=['output'])
| @@ -17,9 +17,9 @@ | |||
| import pytest | |||
| import numpy as np | |||
| from mindspore import Tensor, CSRTensor, ms_function | |||
| from mindspore import Tensor, CSRTensor, ms_function, nn, context | |||
| from mindspore.ops.operations import _csr_ops | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore import nn, context | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| @@ -200,3 +200,62 @@ def test_csr_tensor_in_while_cpu(): | |||
| assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0) | |||
| assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0) | |||
| assert shape == out.shape | |||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_csr_ops():
    """
    Feature: Test CSR-related Ops.
    Description: Test CSRReduceSum, CSRMul, CSRMV.
    Expectation: Success.
    """
    class CSROpNet(nn.Cell):
        """Generic cell: builds a CSRTensor from components, applies `op` to it and `arg`."""

        def __init__(self, op):
            super(CSROpNet, self).__init__()
            self.op = op

        def construct(self, indptr, indices, values, dense_shape, arg):
            csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
            return self.op(csr_tensor, arg)

    # Shared 2x4 CSR fixture: row 0 holds 2.0 at column 0, row 1 holds 1.0 at column 1.
    indptr = Tensor([0, 1, 2])
    indices = Tensor([0, 1])
    values = Tensor([2, 1], dtype=mstype.float32)
    dense_shape = (2, 4)
    dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
    dense_vector = Tensor([[1.], [1], [1], [1]], dtype=mstype.float32)

    # (op, second argument, expected dense result) — one case per CSR primitive.
    cases = [
        (_csr_ops.CSRReduceSum(), 1, np.array([[2.], [1.]], dtype=np.float32)),
        (_csr_ops.CSRMul(), dense_tensor, np.array([2., 1.], dtype=np.float32)),
        (_csr_ops.CSRMV(), dense_vector, np.array([[2.], [1.]], dtype=np.float32)),
    ]
    for op, arg, expect in cases:
        out = CSROpNet(op)(indptr, indices, values, dense_shape, arg)
        assert np.allclose(out.asnumpy(), expect)